{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vgUXRKpJh1yE"
   },
   "source": [
    "# FedChill - CIFAR-10\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-09-23T04:14:06.303913Z",
     "iopub.status.busy": "2025-09-23T04:14:06.303050Z",
     "iopub.status.idle": "2025-09-23T04:14:18.375208Z",
     "shell.execute_reply": "2025-09-23T04:14:18.374479Z",
     "shell.execute_reply.started": "2025-09-23T04:14:06.303870Z"
    },
    "id": "_dq_ZRdAh1yG",
    "outputId": "c38ff43a-2f89-4321-d643-5b651b9429aa",
    "trusted": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import numpy as np\n",
    "from torch.utils.data import Dataset, DataLoader, Subset, random_split\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import Counter\n",
    "import copy\n",
    "import random\n",
    "from typing import Dict, List, Tuple\n",
    "import math\n",
    "\n",
    "import pandas as pd\n",
    "# Seed the torch, NumPy and stdlib random generators with the same fixed value.\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "random.seed(42)\n",
    "\n",
    "# Use CUDA when torch reports it as available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2NKqWJJTS7GB"
   },
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-23T04:14:21.898430Z",
     "iopub.status.busy": "2025-09-23T04:14:21.897693Z",
     "iopub.status.idle": "2025-09-23T04:14:21.901701Z",
     "shell.execute_reply": "2025-09-23T04:14:21.901072Z",
     "shell.execute_reply.started": "2025-09-23T04:14:21.898403Z"
    },
    "id": "hHTqcMmPeQ1c",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "# Number of target classes: default width of SimpleCNN's output layer and the class count used when partitioning.\n",
    "NUM_CLASSES = 10 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-23T04:14:28.015681Z",
     "iopub.status.busy": "2025-09-23T04:14:28.015392Z",
     "iopub.status.idle": "2025-09-23T04:14:28.032077Z",
     "shell.execute_reply": "2025-09-23T04:14:28.031306Z",
     "shell.execute_reply.started": "2025-09-23T04:14:28.015660Z"
    },
    "id": "gVCZKSAhh1yH",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "class SimpleCNN(nn.Module):\n",
    "    \"\"\"Convolutional image classifier: seven 3x3 conv + batch-norm layers and a three-layer head.\n",
    "\n",
    "    Args:\n",
    "        num_classes: Width of the final linear layer (number of logits).\n",
    "        seed: Seed applied only while the parameters are initialised, so instances\n",
    "            built with the same seed start from identical weights.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes=NUM_CLASSES, seed=42):\n",
    "        super(SimpleCNN, self).__init__()\n",
    "\n",
    "        # fork_rng snapshots the CPU generator and restores it on exit, so the\n",
    "        # deterministic initialisation below no longer resets the caller's random\n",
    "        # stream. (The earlier `torch.manual_seed(torch.initial_seed())` merely\n",
    "        # re-applied `seed`, because initial_seed() reports the last manual seed.)\n",
    "        # devices=[] keeps the fork CPU-only; torch.manual_seed still reseeds the\n",
    "        # CUDA generators, exactly as it did before.\n",
    "        with torch.random.fork_rng(devices=[]):\n",
    "            torch.manual_seed(seed)\n",
    "\n",
    "            # Registration order is significant: it fixes the state_dict key order\n",
    "            # and the order in which _initialize_weights draws random numbers.\n",
    "            self.conv1 = nn.Conv2d(3, 64, 3, padding=1)\n",
    "            self.bn1 = nn.BatchNorm2d(64)\n",
    "            self.conv2 = nn.Conv2d(64, 64, 3, padding=1)\n",
    "            self.bn2 = nn.BatchNorm2d(64)\n",
    "\n",
    "            self.conv3 = nn.Conv2d(64, 128, 3, padding=1)\n",
    "            self.bn3 = nn.BatchNorm2d(128)\n",
    "            self.conv4 = nn.Conv2d(128, 128, 3, padding=1)\n",
    "            self.bn4 = nn.BatchNorm2d(128)\n",
    "\n",
    "            self.conv5 = nn.Conv2d(128, 256, 3, padding=1)\n",
    "            self.bn5 = nn.BatchNorm2d(256)\n",
    "            self.conv6 = nn.Conv2d(256, 256, 3, padding=1)\n",
    "            self.bn6 = nn.BatchNorm2d(256)\n",
    "\n",
    "            self.conv7 = nn.Conv2d(256, 512, 3, padding=1)\n",
    "            self.bn7 = nn.BatchNorm2d(512)\n",
    "\n",
    "            self.pool = nn.MaxPool2d(2, 2)\n",
    "            # Pools to a fixed 2x2 map so fc1 always receives 512 * 2 * 2 features.\n",
    "            self.adaptive_pool = nn.AdaptiveAvgPool2d((2, 2))\n",
    "\n",
    "            self.fc1 = nn.Linear(512 * 2 * 2, 1024)\n",
    "            self.bn_fc1 = nn.BatchNorm1d(1024)\n",
    "            self.dropout1 = nn.Dropout(0.5)\n",
    "\n",
    "            self.fc2 = nn.Linear(1024, 512)\n",
    "            self.bn_fc2 = nn.BatchNorm1d(512)\n",
    "            self.dropout2 = nn.Dropout(0.3)\n",
    "\n",
    "            self.fc3 = nn.Linear(512, num_classes)\n",
    "\n",
    "            self._initialize_weights()\n",
    "\n",
    "    def _initialize_weights(self):\n",
    "        \"\"\"Kaiming-normal conv weights, N(0, 0.01) linear weights, unit batch-norm scale, zero biases.\"\"\"\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
    "                if m.bias is not None:\n",
    "                    nn.init.constant_(m.bias, 0)\n",
    "            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):\n",
    "                nn.init.constant_(m.weight, 1)\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "            elif isinstance(m, nn.Linear):\n",
    "                nn.init.normal_(m.weight, 0, 0.01)\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "\n",
    "    def get_features(self, x):\n",
    "        \"\"\"Return the activations of the second fully connected block (before dropout2 and fc3).\"\"\"\n",
    "        # Three conv-conv-pool stages.\n",
    "        x = F.relu(self.bn1(self.conv1(x)))\n",
    "        x = F.relu(self.bn2(self.conv2(x)))\n",
    "        x = self.pool(x)\n",
    "\n",
    "        x = F.relu(self.bn3(self.conv3(x)))\n",
    "        x = F.relu(self.bn4(self.conv4(x)))\n",
    "        x = self.pool(x)\n",
    "\n",
    "        x = F.relu(self.bn5(self.conv5(x)))\n",
    "        x = F.relu(self.bn6(self.conv6(x)))\n",
    "        x = self.pool(x)\n",
    "\n",
    "        # Final conv stage, reduced to a fixed 2x2 spatial map.\n",
    "        x = F.relu(self.bn7(self.conv7(x)))\n",
    "        x = self.adaptive_pool(x)\n",
    "\n",
    "        # Flatten to (batch, features) for the fully connected head.\n",
    "        x = x.view(x.size(0), -1)\n",
    "        x = F.relu(self.bn_fc1(self.fc1(x)))\n",
    "        x = self.dropout1(x)\n",
    "        x = F.relu(self.bn_fc2(self.fc2(x)))\n",
    "\n",
    "        return x\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the class logits for a batch of images.\"\"\"\n",
    "        # Identical pipeline to get_features, followed by the last dropout and the classifier layer.\n",
    "        x = self.get_features(x)\n",
    "        x = self.dropout2(x)\n",
    "        return self.fc3(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-23T04:14:30.318916Z",
     "iopub.status.busy": "2025-09-23T04:14:30.318578Z",
     "iopub.status.idle": "2025-09-23T04:14:30.323552Z",
     "shell.execute_reply": "2025-09-23T04:14:30.322783Z",
     "shell.execute_reply.started": "2025-09-23T04:14:30.318892Z"
    },
    "id": "fKFX-gLRddsE",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "# Default batch size passed to load_and_partition_data and on to its DataLoaders.\n",
    "BATCH_SIZE = 64\n",
    "# NOTE(review): not referenced in the visible cells (client learning rates come from AdaptiveClientHyperparams) - confirm it is still used.\n",
    "LEARNING_RATE = 0.01\n",
    "# Presumably the number of local epochs per client per round - TODO confirm against run_fedchill.\n",
    "LOCAL_EPOCHS = 5\n",
    "# Default number of client partitions created by load_and_partition_data.\n",
    "NUM_OF_CLIENTS = 20\n",
    "# Number of communication rounds iterated by run_fedchill.\n",
    "COMM_ROUND = 30\n",
    "# Dirichlet concentration parameter used when drawing per-class client proportions.\n",
    "ALPHA = 0.5  \n",
    "# Fraction of each client's train partition that is kept (validation keeps min(2 * FRAC, 1.0)).\n",
    "FRAC = 0.2   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oJrL1U1Hh1yJ"
   },
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-23T04:14:32.611069Z",
     "iopub.status.busy": "2025-09-23T04:14:32.610781Z",
     "iopub.status.idle": "2025-09-23T04:14:32.626518Z",
     "shell.execute_reply": "2025-09-23T04:14:32.625752Z",
     "shell.execute_reply.started": "2025-09-23T04:14:32.611049Z"
    },
    "id": "CImsSomMh1yH",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def load_and_partition_data(num_clients=NUM_OF_CLIENTS, alpha=ALPHA, batch_size=BATCH_SIZE, frac=FRAC, rand_seed=42):\n",
    "    \"\"\"Download CIFAR-10 and split it into per-client non-IID loaders via a Dirichlet prior.\n",
    "\n",
    "    The same per-class Dirichlet draw seeds both the train and the test split of a client.\n",
    "\n",
    "    Args:\n",
    "        num_clients: Number of client partitions to create.\n",
    "        alpha: Dirichlet concentration parameter used for the per-class proportions.\n",
    "        batch_size: Batch size of every returned DataLoader.\n",
    "        frac: Fraction of each client's train partition that is kept; the validation\n",
    "            subset keeps min(2 * frac, 1.0) of the client's test partition.\n",
    "        rand_seed: Seed for the global torch/NumPy RNGs and for the per-loader generators.\n",
    "            Both global RNGs are re-seeded with it again just before returning.\n",
    "\n",
    "    Returns:\n",
    "        (client_train_loaders, client_val_loaders, test_loader, client_class_distributions);\n",
    "        the last item holds one {class: fraction} dict per client, computed from the\n",
    "        labels of its training subset.\n",
    "    \"\"\"\n",
    "\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "    ])\n",
    "\n",
    "    torch.manual_seed(rand_seed)\n",
    "    np.random.seed(rand_seed)\n",
    "\n",
    "    full_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "    y_train = np.array(full_dataset.targets)\n",
    "    y_test = np.array(test_dataset.targets)\n",
    "\n",
    "    num_classes = NUM_CLASSES\n",
    "    N = len(full_dataset)\n",
    "    N_test = len(test_dataset)\n",
    "\n",
    "    net_dataidx_map = {}\n",
    "    net_dataidx_map_test = {}\n",
    "\n",
    "    # Redraw the whole partition until every client owns at least 10 training samples.\n",
    "    # The order of the NumPy RNG calls below determines the partition; do not reorder them.\n",
    "    min_size = 0\n",
    "    while min_size < 10:\n",
    "        idx_batch = [[] for _ in range(num_clients)]\n",
    "        idx_batch_test = [[] for _ in range(num_clients)]\n",
    "        for k in range(num_classes):\n",
    "            idx_k = np.where(y_train == k)[0]\n",
    "            idx_k_test = np.where(y_test == k)[0]\n",
    "            np.random.shuffle(idx_k)\n",
    "            np.random.shuffle(idx_k_test)\n",
    "            proportions = np.random.dirichlet(np.repeat(alpha, num_clients))\n",
    "            # Clients that already hold their even share (N / num_clients) get no more of this class.\n",
    "            proportions_train = np.array([p * (len(idx_j) < N / num_clients) for p, idx_j in zip(proportions, idx_batch)])\n",
    "            proportions_test = np.array([p * (len(idx_j) < N_test / num_clients) for p, idx_j in zip(proportions, idx_batch_test)])\n",
    "            proportions_train = proportions_train / proportions_train.sum()\n",
    "            proportions_test = proportions_test / proportions_test.sum()\n",
    "            # Turn the proportions into split points inside this class's shuffled index array.\n",
    "            proportions_train = (np.cumsum(proportions_train) * len(idx_k)).astype(int)[:-1]\n",
    "            proportions_test = (np.cumsum(proportions_test) * len(idx_k_test)).astype(int)[:-1]\n",
    "            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions_train))]\n",
    "            idx_batch_test = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch_test, np.split(idx_k_test, proportions_test))]\n",
    "        min_size = min([len(idx_j) for idx_j in idx_batch])\n",
    "\n",
    "    for j in range(num_clients):\n",
    "        np.random.shuffle(idx_batch[j])\n",
    "        np.random.shuffle(idx_batch_test[j])\n",
    "        net_dataidx_map[j] = idx_batch[j]\n",
    "        net_dataidx_map_test[j] = idx_batch_test[j]\n",
    "\n",
    "    client_train_loaders = []\n",
    "    client_val_loaders = []\n",
    "    client_class_distributions = []\n",
    "\n",
    "    for i in range(num_clients):\n",
    "        # Per-client seed so each client's subsample does not depend on the other clients.\n",
    "        np.random.seed(rand_seed + i)\n",
    "\n",
    "        num_data = len(net_dataidx_map[i])\n",
    "        frac_num_data = int(frac * num_data)\n",
    "        frac_indices = np.random.choice(num_data, frac_num_data, replace=False)\n",
    "        train_indices = [net_dataidx_map[i][j] for j in frac_indices]\n",
    "\n",
    "        num_data_test = len(net_dataidx_map_test[i])\n",
    "        frac_num_data_test = int(min(2 * frac, 1.0) * num_data_test)\n",
    "        frac_indices_test = np.random.choice(num_data_test, frac_num_data_test, replace=False)\n",
    "        val_indices = [net_dataidx_map_test[i][j] for j in frac_indices_test]\n",
    "\n",
    "        client_labels = [y_train[idx] for idx in train_indices]\n",
    "        class_counts = Counter(client_labels)\n",
    "        distribution = {cls: class_counts.get(cls, 0) / len(client_labels) if len(client_labels) > 0 else 0 for cls in range(num_classes)}\n",
    "        client_class_distributions.append(distribution)\n",
    "\n",
    "        client_train_dataset = Subset(full_dataset, train_indices)\n",
    "        client_val_dataset = Subset(test_dataset, val_indices)\n",
    "\n",
    "        g_train = torch.Generator().manual_seed(rand_seed + i)\n",
    "\n",
    "        train_loader = DataLoader(client_train_dataset, batch_size=batch_size,\n",
    "                                 shuffle=True, generator=g_train, drop_last=True)\n",
    "        # Validation must score every sample: no shuffling and no drop_last. Dropping the\n",
    "        # tail batch discarded samples and left an empty loader (division by zero in\n",
    "        # evaluate_model) for clients with fewer than batch_size validation samples.\n",
    "        val_loader = DataLoader(client_val_dataset, batch_size=batch_size,\n",
    "                                shuffle=False, drop_last=False)\n",
    "\n",
    "        client_train_loaders.append(train_loader)\n",
    "        client_val_loaders.append(val_loader)\n",
    "\n",
    "    g_test = torch.Generator().manual_seed(rand_seed + 2 * num_clients + 1)\n",
    "    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True,\n",
    "                             generator=g_test, num_workers=2)\n",
    "\n",
    "    print(\"Data partitioning complete.\")\n",
    "    # Leave the global RNGs in a known state for whatever runs next.\n",
    "    torch.manual_seed(rand_seed)\n",
    "    np.random.seed(rand_seed)\n",
    "\n",
    "    return client_train_loaders, client_val_loaders, test_loader, client_class_distributions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OOr1Qsu_S7GC"
   },
   "source": [
    "## Utility Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-23T04:14:41.715124Z",
     "iopub.status.busy": "2025-09-23T04:14:41.714323Z",
     "iopub.status.idle": "2025-09-23T04:14:41.725568Z",
     "shell.execute_reply": "2025-09-23T04:14:41.724782Z",
     "shell.execute_reply.started": "2025-09-23T04:14:41.715096Z"
    },
    "id": "f2GMhq3Qh1yJ",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def client_train(model, train_loader, optimizer, epochs, print_flag=True):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        for inputs, targets in train_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = F.cross_entropy(outputs, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        if print_flag:\n",
    "            print(f\"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}\")\n",
    "\n",
    "def aggregate_models(global_model, client_models, client_indices, train_loaders):\n",
    "    global_dict = global_model.state_dict()\n",
    "\n",
    "    n_samples = [len(train_loaders[idx].dataset) for idx in client_indices]\n",
    "    total_samples = sum(n_samples)\n",
    "\n",
    "    aggregated_dict = {}\n",
    "\n",
    "    for key in global_dict.keys():\n",
    "        aggregated_dict[key] = torch.zeros_like(global_dict[key], dtype=torch.float32)\n",
    "\n",
    "    for i, idx in enumerate(client_indices):\n",
    "        client_dict = client_models[idx].state_dict()\n",
    "        weight = n_samples[i] / total_samples\n",
    "\n",
    "        for key in global_dict.keys():\n",
    "            aggregated_dict[key] += client_dict[key].float() * weight\n",
    "\n",
    "    for key in global_dict.keys():\n",
    "        global_dict[key] = aggregated_dict[key].to(dtype=global_dict[key].dtype)\n",
    "\n",
    "    global_model.load_state_dict(global_dict)\n",
    "\n",
    "\n",
    "def evaluate_model(model, data_loader):\n",
    "\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in data_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            outputs = model(inputs)\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            total += targets.size(0)\n",
    "            correct += (predicted == targets).sum().item()\n",
    "    return 100 * correct / total\n",
    "\n",
    "def client_train_with_temp(model, train_loader, optimizer, epochs, temperature=1.0, print_flag=True):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        for inputs, targets in train_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "\n",
    "            if temperature != 1.0:\n",
    "                log_probs = F.log_softmax(outputs / temperature, dim=1)\n",
    "                loss = F.nll_loss(log_probs, targets)\n",
    "            else:\n",
    "                loss = F.cross_entropy(outputs, targets)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        if print_flag:\n",
    "            print(f\"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** the temperature adjustment factor (`temp_adjust_factor`) was changed in the cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-23T04:14:47.449226Z",
     "iopub.status.busy": "2025-09-23T04:14:47.448718Z",
     "iopub.status.idle": "2025-09-23T04:14:47.467332Z",
     "shell.execute_reply": "2025-09-23T04:14:47.466445Z",
     "shell.execute_reply.started": "2025-09-23T04:14:47.449203Z"
    },
    "id": "g0f5rDJn68sE",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "class AdaptiveClientHyperparams:\n",
    "    def __init__(self, initial_lr=0.01, min_lr=0.0001, max_lr=0.05,\n",
    "                 initial_temp=0.5, min_temp=0.05, max_temp=1.0,\n",
    "                 patience=2, lr_decay_factor=0.7, temp_adjust_factor=0.95, scaling_factor=3):\n",
    "        self.lr = initial_lr\n",
    "        self.min_lr = min_lr\n",
    "        self.max_lr = max_lr\n",
    "        self.lr_decay_factor = lr_decay_factor\n",
    "        self.scaling_factor = scaling_factor\n",
    "\n",
    "        self.temp = initial_temp\n",
    "        self.min_temp = min_temp\n",
    "        self.max_temp = max_temp\n",
    "        self.temp_adjust_factor = temp_adjust_factor\n",
    "\n",
    "        self.patience = patience\n",
    "        self.performance_history = []\n",
    "        self.stagnation_count = 0\n",
    "        self.improvement_count = 0\n",
    "        self.het_score = 0.0    \n",
    "\n",
    "    def update_temp_from_heterogeneity(self, het_score):\n",
    "        self.het_score = het_score\n",
    "        self.temp = self.max_temp * math.exp(-self.scaling_factor * het_score) \n",
    "        self.temp = min(max(self.temp, self.min_temp), self.max_temp)\n",
    "\n",
    "    def record_performance(self, accuracy):\n",
    "        self.performance_history.append(accuracy)\n",
    "\n",
    "        if len(self.performance_history) >= 3:\n",
    "            if self.performance_history[-1] <= self.performance_history[-2]:\n",
    "                self.stagnation_count += 1\n",
    "                self.improvement_count = 0\n",
    "            else:\n",
    "                self.improvement_count += 1\n",
    "                self.stagnation_count = 0\n",
    "\n",
    "            if self.stagnation_count >= self.patience:\n",
    "                if self.temp > self.min_temp * 1.1:  \n",
    "                    self.temp *= self.temp_adjust_factor\n",
    "                    self.temp = max(self.temp, self.min_temp)\n",
    "                    self.stagnation_count = 0\n",
    "                    return True, \"decrease-temp\", self.temp\n",
    "\n",
    "        return False, \"unchanged\", None\n",
    "\n",
    "    def get_lr(self):\n",
    "        return self.lr\n",
    "\n",
    "    def get_temp(self):\n",
    "        return self.temp\n",
    "\n",
    "\n",
    "def calculate_heterogeneity_score(client_distribution, global_distribution):\n",
    "    score = 0\n",
    "    for cls in range(10):\n",
    "        client_prob = client_distribution.get(cls, 0)\n",
    "        global_prob = global_distribution.get(cls, 1/10)\n",
    "        if client_prob > 0 and global_prob > 0:\n",
    "            ratio = client_prob / global_prob\n",
    "            score += abs(ratio - 1)\n",
    "\n",
    "    score = min(score / 10, 1)\n",
    "    return score\n",
    "\n",
    "def calculate_global_distribution(client_class_distributions):\n",
    "    global_dist = {}\n",
    "    n_clients = len(client_class_distributions)\n",
    "\n",
    "    for cls in range(10):\n",
    "        global_dist[cls] = sum(dist.get(cls, 0) for dist in client_class_distributions) / n_clients\n",
    "\n",
    "    for cls in range(10):\n",
    "        if global_dist[cls] < 0.01:\n",
    "            global_dist[cls] = 0.01\n",
    "\n",
    "    total = sum(global_dist.values())\n",
    "    global_dist = {k: v/total for k, v in global_dist.items()}\n",
    "\n",
    "    return global_dist\n",
    "\n",
    "def client_train(model, train_loader, optimizer, epochs, print_flag=True):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        for inputs, targets in train_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = F.cross_entropy(outputs, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        if print_flag:\n",
    "            print(f\"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}\")\n",
    "\n",
    "def aggregate_models(global_model, client_models, client_indices, train_loaders):\n",
    "    \n",
    "    global_dict = global_model.state_dict()\n",
    "\n",
    "    n_samples = [len(train_loaders[idx].dataset) for idx in client_indices]\n",
    "    total_samples = sum(n_samples)\n",
    "\n",
    "    aggregated_dict = {}\n",
    "\n",
    "    for key in global_dict.keys():\n",
    "        aggregated_dict[key] = torch.zeros_like(global_dict[key], dtype=torch.float32)\n",
    "\n",
    "    for i, idx in enumerate(client_indices):\n",
    "        client_dict = client_models[idx].state_dict()\n",
    "        weight = n_samples[i] / total_samples\n",
    "\n",
    "        for key in global_dict.keys():\n",
    "            aggregated_dict[key] += client_dict[key].float() * weight\n",
    "\n",
    "    for key in global_dict.keys():\n",
    "        global_dict[key] = aggregated_dict[key].to(dtype=global_dict[key].dtype)\n",
    "\n",
    "    global_model.load_state_dict(global_dict)\n",
    "\n",
    "\n",
    "def evaluate_model(model, data_loader):\n",
    "\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in data_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            outputs = model(inputs)\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            total += targets.size(0)\n",
    "            correct += (predicted == targets).sum().item()\n",
    "    return 100 * correct / total\n",
    "\n",
    "def client_train_with_temp(model, train_loader, optimizer, epochs, temperature=1.0, print_flag=True):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        for inputs, targets in train_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "\n",
    "            if temperature != 1.0:\n",
    "                log_probs = F.log_softmax(outputs / temperature, dim=1)\n",
    "                loss = F.nll_loss(log_probs, targets)\n",
    "            else:\n",
    "                loss = F.cross_entropy(outputs, targets)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        if print_flag:\n",
    "            print(f\"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1-V6QoWgA2hB"
   },
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6VOgsXc-A2hB"
   },
   "source": [
    "**HYPERPARAMETERS**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-23T04:15:10.857678Z",
     "iopub.status.busy": "2025-09-23T04:15:10.857352Z",
     "iopub.status.idle": "2025-09-23T04:15:10.861479Z",
     "shell.execute_reply": "2025-09-23T04:15:10.860883Z",
     "shell.execute_reply.started": "2025-09-23T04:15:10.857655Z"
    },
    "id": "RlQUq8Ht61cg",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "# NOTE(review): these constants repeat, with identical values, an earlier configuration cell in this notebook.\n",
    "# Default batch size passed to load_and_partition_data and on to its DataLoaders.\n",
    "BATCH_SIZE = 64\n",
    "# NOTE(review): not referenced in the visible cells (client learning rates come from AdaptiveClientHyperparams) - confirm it is still used.\n",
    "LEARNING_RATE = 0.01\n",
    "# Presumably the number of local epochs per client per round - TODO confirm against run_fedchill.\n",
    "LOCAL_EPOCHS = 5\n",
    "# Default number of client partitions created by load_and_partition_data.\n",
    "NUM_OF_CLIENTS = 20\n",
    "# Number of communication rounds iterated by run_fedchill.\n",
    "COMM_ROUND = 30\n",
    "# Dirichlet concentration parameter used when drawing per-class client proportions.\n",
    "ALPHA = 0.5  \n",
    "# Fraction of each client's train partition that is kept (validation keeps min(2 * FRAC, 1.0)).\n",
    "FRAC = 0.2   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-22T20:25:34.527068Z",
     "iopub.status.busy": "2025-09-22T20:25:34.526368Z",
     "iopub.status.idle": "2025-09-22T20:25:34.534092Z",
     "shell.execute_reply": "2025-09-22T20:25:34.533439Z",
     "shell.execute_reply.started": "2025-09-22T20:25:34.527043Z"
    },
    "id": "48RvxVUtA2hB",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def run_experiment(scaling_factors=None):\n",
    "    \"\"\"Run FedChill once per scaling factor and collect the results.\n",
    "\n",
    "    Args:\n",
    "        scaling_factors: Iterable of scaling factors to sweep. Defaults to\n",
    "            0.5, 1.0, ..., 3.5.\n",
    "\n",
    "    Returns:\n",
    "        Dict keyed by scaling factor. Each value holds 'server_acc' (last entry of\n",
    "        the server accuracy history), 'client_acc' (as returned by run_fedchill),\n",
    "        'avg_client_acc' (mean of 'client_acc' over its first axis) and 'final_model'.\n",
    "        NOTE(review): 'client_acc' is assumed to be one accuracy history per client,\n",
    "        all of equal length - confirm against run_fedchill's return value.\n",
    "    \"\"\"\n",
    "    if scaling_factors is None:\n",
    "        scaling_factors = np.arange(0.5, 4.0, 0.5).tolist()\n",
    "\n",
    "    print(f\"\\n\\n{'='*80}\")\n",
    "    print(f\"{'='*80}\\n\")\n",
    "\n",
    "    model_results = {}\n",
    "\n",
    "    for sf in scaling_factors:\n",
    "        print(f\"\\n\\n{'#'*50}\")\n",
    "        print(f\"# RUNNING FedChill with Scaling Factor {sf}\")\n",
    "        print(f\"{'#'*50}\\n\")\n",
    "\n",
    "        server_acc_history, client_acc, final_model = run_fedchill(scaling_factor=sf)\n",
    "\n",
    "        model_results[sf] = {\n",
    "            'server_acc': server_acc_history[-1],\n",
    "            'client_acc': client_acc,\n",
    "            # Previously computed and then discarded; kept so callers can use it.\n",
    "            'avg_client_acc': np.mean(client_acc, axis=0).tolist(),\n",
    "            'final_model': final_model\n",
    "        }\n",
    "\n",
    "    return model_results\n",
    "\n",
    "def plot_model_comparison(results):\n",
    "    \"\"\"Bar chart: Server vs Client accuracy across scaling factors.\n",
    "\n",
    "    Args:\n",
    "        results: Mapping of scaling factor -> dict with a 'server_acc' entry (a final\n",
    "            accuracy or an accuracy history) and a 'client_acc' entry.\n",
    "            NOTE(review): 'client_acc' is assumed to hold one accuracy history per\n",
    "            client, as returned by run_fedchill - confirm against its return value.\n",
    "\n",
    "    The figure is shown and also saved to scaling_factor_comparison.png.\n",
    "    \"\"\"\n",
    "    if not results:\n",
    "        print(\"No results to plot.\")\n",
    "        return\n",
    "\n",
    "    def _final(value):\n",
    "        # Accept a bare number or a history; a history contributes its last entry.\n",
    "        return float(np.ravel(np.asarray(value, dtype=float))[-1])\n",
    "\n",
    "    scaling_factors = list(results.keys())\n",
    "\n",
    "    # These two series were commented out before, so the bar calls below raised NameError.\n",
    "    server_acc = [_final(data['server_acc']) for data in results.values()]\n",
    "    client_acc = []\n",
    "    for data in results.values():\n",
    "        histories = data['client_acc']\n",
    "        if isinstance(histories, (list, tuple, np.ndarray)):\n",
    "            # Average the final accuracy over all clients.\n",
    "            client_acc.append(float(np.mean([_final(history) for history in histories])))\n",
    "        else:\n",
    "            client_acc.append(float(histories))\n",
    "\n",
    "    x = np.arange(len(scaling_factors))\n",
    "    width = 0.35\n",
    "\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    plt.bar(x - width/2, server_acc, width, label='Server Acc (%)')\n",
    "    plt.bar(x + width/2, client_acc, width, label='Average Client Acc (%)')\n",
    "\n",
    "    plt.xticks(x, [str(sf) for sf in scaling_factors])\n",
    "    plt.xlabel(\"Scaling Factor\")\n",
    "    plt.ylabel(\"Accuracy (%)\")\n",
    "    plt.title(\"Server vs Client Accuracy by Scaling Factor\")\n",
    "    plt.legend()\n",
    "    plt.grid(axis='y', linestyle='--', alpha=0.7)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"scaling_factor_comparison.png\")\n",
    "    plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-23T04:15:47.341577Z",
     "iopub.status.busy": "2025-09-23T04:15:47.340988Z",
     "iopub.status.idle": "2025-09-23T04:15:47.348648Z",
     "shell.execute_reply": "2025-09-23T04:15:47.347957Z",
     "shell.execute_reply.started": "2025-09-23T04:15:47.341545Z"
    },
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def run_experiment(scaling_factors=None):\n",
    "    \"\"\"Run FedChill once per scaling factor and collect the results.\n",
    "\n",
    "    Args:\n",
    "        scaling_factors: Iterable of scaling factors to sweep. Defaults to\n",
    "            1.5, 2.0, ..., 3.5.\n",
    "\n",
    "    Returns:\n",
    "        Dict keyed by scaling factor. Each value holds 'server_acc' (last entry of\n",
    "        the server accuracy history), 'client_acc' (as returned by run_fedchill),\n",
    "        'avg_client_acc' (mean of 'client_acc' over its first axis) and 'final_model'.\n",
    "        NOTE(review): 'client_acc' is assumed to be one accuracy history per client,\n",
    "        all of equal length - confirm against run_fedchill's return value.\n",
    "    \"\"\"\n",
    "    if scaling_factors is None:\n",
    "        scaling_factors = np.arange(1.5, 4.0, 0.5).tolist()\n",
    "\n",
    "    print(f\"\\n\\n{'='*80}\")\n",
    "    print(f\"{'='*80}\\n\")\n",
    "\n",
    "    model_results = {}\n",
    "\n",
    "    for sf in scaling_factors:\n",
    "        print(f\"\\n\\n{'#'*50}\")\n",
    "        print(f\"# RUNNING FedChill with Scaling Factor {sf}\")\n",
    "        print(f\"{'#'*50}\\n\")\n",
    "\n",
    "        server_acc_history, client_acc, final_model = run_fedchill(scaling_factor=sf)\n",
    "\n",
    "        model_results[sf] = {\n",
    "            'server_acc': server_acc_history[-1],\n",
    "            'client_acc': client_acc,\n",
    "            # Previously computed and then discarded; kept so callers can use it.\n",
    "            'avg_client_acc': np.mean(client_acc, axis=0).tolist(),\n",
    "            'final_model': final_model\n",
    "        }\n",
    "\n",
    "    return model_results\n",
    "\n",
    "def plot_model_comparison(results):\n",
    "    \"\"\"Bar chart: Server vs Client accuracy across scaling factors.\n",
    "\n",
    "    Args:\n",
    "        results: Mapping of scaling factor -> dict with a 'server_acc' entry (a final\n",
    "            accuracy or an accuracy history) and a 'client_acc' entry.\n",
    "            NOTE(review): 'client_acc' is assumed to hold one accuracy history per\n",
    "            client, as returned by run_fedchill - confirm against its return value.\n",
    "\n",
    "    The figure is shown and also saved to scaling_factor_comparison.png.\n",
    "    \"\"\"\n",
    "    if not results:\n",
    "        print(\"No results to plot.\")\n",
    "        return\n",
    "\n",
    "    def _final(value):\n",
    "        # Accept a bare number or a history; a history contributes its last entry.\n",
    "        return float(np.ravel(np.asarray(value, dtype=float))[-1])\n",
    "\n",
    "    scaling_factors = list(results.keys())\n",
    "\n",
    "    # These two series were commented out before, so the bar calls below raised NameError.\n",
    "    server_acc = [_final(data['server_acc']) for data in results.values()]\n",
    "    client_acc = []\n",
    "    for data in results.values():\n",
    "        histories = data['client_acc']\n",
    "        if isinstance(histories, (list, tuple, np.ndarray)):\n",
    "            # Average the final accuracy over all clients.\n",
    "            client_acc.append(float(np.mean([_final(history) for history in histories])))\n",
    "        else:\n",
    "            client_acc.append(float(histories))\n",
    "\n",
    "    x = np.arange(len(scaling_factors))\n",
    "    width = 0.35\n",
    "\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    plt.bar(x - width/2, server_acc, width, label='Server Acc (%)')\n",
    "    plt.bar(x + width/2, client_acc, width, label='Average Client Acc (%)')\n",
    "\n",
    "    plt.xticks(x, [str(sf) for sf in scaling_factors])\n",
    "    plt.xlabel(\"Scaling Factor\")\n",
    "    plt.ylabel(\"Accuracy (%)\")\n",
    "    plt.title(\"Server vs Client Accuracy by Scaling Factor\")\n",
    "    plt.legend()\n",
    "    plt.grid(axis='y', linestyle='--', alpha=0.7)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"scaling_factor_comparison.png\")\n",
    "    plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-23T04:15:52.991357Z",
     "iopub.status.busy": "2025-09-23T04:15:52.990754Z",
     "iopub.status.idle": "2025-09-23T04:15:53.002690Z",
     "shell.execute_reply": "2025-09-23T04:15:53.001909Z",
     "shell.execute_reply.started": "2025-09-23T04:15:52.991332Z"
    },
    "id": "ZGkxcCH6A2hC",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def run_fedchill(scaling_factor=3, model_class=SimpleCNN):\n",
    "    print_flag = False\n",
    "\n",
    "    client_train_loaders, client_val_loaders, test_loader, client_class_distributions = load_and_partition_data(\n",
    "        num_clients=NUM_OF_CLIENTS,\n",
    "        alpha=ALPHA,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        frac=FRAC,\n",
    "        rand_seed=42\n",
    "    )\n",
    "    num_clients = len(client_train_loaders)\n",
    "\n",
    "    global_distribution = calculate_global_distribution(client_class_distributions)\n",
    "\n",
    "    client_params = [AdaptiveClientHyperparams(scaling_factor=scaling_factor) for _ in range(num_clients)]\n",
    "\n",
    "    for i, distribution in enumerate(client_class_distributions):\n",
    "        het_score = calculate_heterogeneity_score(distribution, global_distribution)\n",
    "        client_params[i].update_temp_from_heterogeneity(het_score)\n",
    "        print(f\"Client {i} - Het Score: {het_score:.4f}, Initial temp: {client_params[i].get_temp():.4f}\")\n",
    "\n",
    "    global_model = model_class().to(device)\n",
    "\n",
    "    round_server_acc_list = []\n",
    "    round_client_acc_list = [[] for _ in range(num_clients)]\n",
    "\n",
    "    client_temp_history = [[] for _ in range(num_clients)]\n",
    "    client_lr_history = [[] for _ in range(num_clients)]\n",
    "\n",
    "    client_models = [copy.deepcopy(global_model).to(device) for _ in range(num_clients)]\n",
    "    client_optimizers = [optim.SGD(model.parameters(), lr=client_params[i].get_lr()) for i, model in enumerate(client_models)]\n",
    "\n",
    "    for comm_round in range(COMM_ROUND):\n",
    "        print(f\"\\n{'='*20} COMMUNICATION ROUND {comm_round+1} {'='*20}\")\n",
    "\n",
    "        selected_clients = list(range(num_clients))\n",
    "\n",
    "        if comm_round == COMM_ROUND - 1:\n",
    "            print_flag = True\n",
    "            print(\"\\n--- LOCAL TRAINING OF CLIENTS ---\")\n",
    "\n",
    "        for client_idx in selected_clients:\n",
    "            client_models[client_idx].load_state_dict(global_model.state_dict())\n",
    "\n",
    "            curr_lr = client_params[client_idx].get_lr()\n",
    "            curr_temp = client_params[client_idx].get_temp()\n",
    "\n",
    "            client_temp_history[client_idx].append(curr_temp)\n",
    "            client_lr_history[client_idx].append(curr_lr)\n",
    "\n",
    "            client_optimizers[client_idx] = optim.SGD(client_models[client_idx].parameters(), lr=curr_lr)\n",
    "\n",
    "            if comm_round == COMM_ROUND - 1:\n",
    "                print(f\"\\n--- Client {client_idx} Local Training with T={curr_temp:.4f}, LR={curr_lr:.6f} ---\")\n",
    "\n",
    "            client_train_with_temp(\n",
    "                client_models[client_idx],\n",
    "                client_train_loaders[client_idx],\n",
    "                client_optimizers[client_idx],\n",
    "                LOCAL_EPOCHS,\n",
    "                temperature=curr_temp,\n",
    "                print_flag=print_flag\n",
    "            )\n",
    "\n",
    "            local_acc = evaluate_model(client_models[client_idx], client_val_loaders[client_idx])\n",
    "            if comm_round == COMM_ROUND - 1:\n",
    "                print(f\"Client {client_idx} Private Data Validation Accuracy: {local_acc:.2f}%\")\n",
    "\n",
    "            client_testset_acc = evaluate_model(client_models[client_idx], test_loader)\n",
    "            round_client_acc_list[client_idx].append(local_acc)\n",
    "\n",
    "            adjusted, param_type, new_value = client_params[client_idx].record_performance(local_acc)\n",
    "            if adjusted and (comm_round == COMM_ROUND - 1 or print_flag):\n",
    "                print(f\"Client {client_idx} - {param_type} to {new_value:.6f}\")\n",
    "\n",
    "        aggregate_models(global_model, client_models, selected_clients, client_train_loaders)\n",
    "\n",
    "        server_test_val_acc = evaluate_model(global_model, test_loader)\n",
    "        if comm_round == COMM_ROUND - 1:\n",
    "            print(f\"\\n--- Server Evaluation ---\")\n",
    "            print(f\"Server Test Data Validation Accuracy: {server_test_val_acc:.2f}%\")\n",
    "\n",
    "        round_server_acc_list.append(server_test_val_acc)\n",
    "\n",
    "    return round_server_acc_list, round_client_acc_list, global_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-22T20:25:54.507779Z",
     "iopub.status.busy": "2025-09-22T20:25:54.507465Z"
    },
    "trusted": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "================================================================================\n",
      "================================================================================\n",
      "\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 0.5\n",
      "##################################################\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 170M/170M [00:03<00:00, 43.1MB/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.8110, Initial temp: 0.6666\n",
      "Client 1 - Het Score: 0.9049, Initial temp: 0.6361\n",
      "Client 2 - Het Score: 0.8625, Initial temp: 0.6497\n",
      "Client 3 - Het Score: 0.8298, Initial temp: 0.6604\n",
      "Client 4 - Het Score: 0.7894, Initial temp: 0.6739\n",
      "Client 5 - Het Score: 0.6072, Initial temp: 0.7382\n",
      "Client 6 - Het Score: 0.7654, Initial temp: 0.6820\n",
      "Client 7 - Het Score: 0.4441, Initial temp: 0.8009\n",
      "Client 8 - Het Score: 0.8104, Initial temp: 0.6668\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.6065\n",
      "Client 10 - Het Score: 0.9858, Initial temp: 0.6108\n",
      "Client 11 - Het Score: 0.8155, Initial temp: 0.6651\n",
      "Client 12 - Het Score: 0.6630, Initial temp: 0.7178\n",
      "Client 13 - Het Score: 0.5425, Initial temp: 0.7624\n",
      "Client 14 - Het Score: 0.8720, Initial temp: 0.6466\n",
      "Client 15 - Het Score: 0.7941, Initial temp: 0.6723\n",
      "Client 16 - Het Score: 0.6080, Initial temp: 0.7379\n",
      "Client 17 - Het Score: 0.9110, Initial temp: 0.6341\n",
      "Client 18 - Het Score: 0.8198, Initial temp: 0.6637\n",
      "Client 19 - Het Score: 0.8745, Initial temp: 0.6458\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.5430, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5200\n",
      "Epoch [2/5], Loss: 0.2153\n",
      "Epoch [3/5], Loss: 0.0743\n",
      "Epoch [4/5], Loss: 0.0713\n",
      "Epoch [5/5], Loss: 0.0465\n",
      "Client 0 Private Data Validation Accuracy: 69.79%\n",
      "\n",
      "--- Client 1 Local Training with T=0.5454, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4946\n",
      "Epoch [2/5], Loss: 0.1015\n",
      "Epoch [3/5], Loss: 0.0602\n",
      "Epoch [4/5], Loss: 0.0406\n",
      "Epoch [5/5], Loss: 0.0252\n",
      "Client 1 Private Data Validation Accuracy: 82.29%\n",
      "Client 1 - decrease-temp to 0.518088\n",
      "\n",
      "--- Client 2 Local Training with T=0.5864, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4635\n",
      "Epoch [2/5], Loss: 0.1430\n",
      "Epoch [3/5], Loss: 0.0648\n",
      "Epoch [4/5], Loss: 0.0565\n",
      "Epoch [5/5], Loss: 0.0347\n",
      "Client 2 Private Data Validation Accuracy: 71.35%\n",
      "\n",
      "--- Client 3 Local Training with T=0.5662, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5019\n",
      "Epoch [2/5], Loss: 0.1736\n",
      "Epoch [3/5], Loss: 0.0911\n",
      "Epoch [4/5], Loss: 0.0477\n",
      "Epoch [5/5], Loss: 0.0381\n",
      "Client 3 Private Data Validation Accuracy: 67.19%\n",
      "\n",
      "--- Client 4 Local Training with T=0.5489, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6091\n",
      "Epoch [2/5], Loss: 0.3243\n",
      "Epoch [3/5], Loss: 0.1651\n",
      "Epoch [4/5], Loss: 0.0694\n",
      "Epoch [5/5], Loss: 0.0454\n",
      "Client 4 Private Data Validation Accuracy: 67.97%\n",
      "\n",
      "--- Client 5 Local Training with T=0.5712, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5217\n",
      "Epoch [2/5], Loss: 0.1866\n",
      "Epoch [3/5], Loss: 0.0835\n",
      "Epoch [4/5], Loss: 0.0524\n",
      "Epoch [5/5], Loss: 0.0403\n",
      "Client 5 Private Data Validation Accuracy: 68.75%\n",
      "\n",
      "--- Client 6 Local Training with T=0.5847, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3918\n",
      "Epoch [2/5], Loss: 0.1038\n",
      "Epoch [3/5], Loss: 0.0480\n",
      "Epoch [4/5], Loss: 0.0201\n",
      "Epoch [5/5], Loss: 0.0255\n",
      "Client 6 Private Data Validation Accuracy: 71.88%\n",
      "\n",
      "--- Client 7 Local Training with T=0.6866, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3221\n",
      "Epoch [2/5], Loss: 0.1520\n",
      "Epoch [3/5], Loss: 0.0903\n",
      "Epoch [4/5], Loss: 0.0674\n",
      "Epoch [5/5], Loss: 0.0556\n",
      "Client 7 Private Data Validation Accuracy: 61.72%\n",
      "\n",
      "--- Client 8 Local Training with T=0.5432, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.8927\n",
      "Epoch [2/5], Loss: 0.2230\n",
      "Epoch [3/5], Loss: 0.0980\n",
      "Epoch [4/5], Loss: 0.0526\n",
      "Epoch [5/5], Loss: 0.0503\n",
      "Client 8 Private Data Validation Accuracy: 64.84%\n",
      "\n",
      "--- Client 9 Local Training with T=0.4459, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3286\n",
      "Epoch [2/5], Loss: 0.0988\n",
      "Epoch [3/5], Loss: 0.0782\n",
      "Epoch [4/5], Loss: 0.0419\n",
      "Epoch [5/5], Loss: 0.0283\n",
      "Client 9 Private Data Validation Accuracy: 78.91%\n",
      "\n",
      "--- Client 10 Local Training with T=0.5513, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6229\n",
      "Epoch [2/5], Loss: 0.1169\n",
      "Epoch [3/5], Loss: 0.0415\n",
      "Epoch [4/5], Loss: 0.0427\n",
      "Epoch [5/5], Loss: 0.0461\n",
      "Client 10 Private Data Validation Accuracy: 69.92%\n",
      "\n",
      "--- Client 11 Local Training with T=0.6003, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.4902\n",
      "Epoch [2/5], Loss: 0.4247\n",
      "Epoch [3/5], Loss: 0.1779\n",
      "Epoch [4/5], Loss: 0.0767\n",
      "Epoch [5/5], Loss: 0.0571\n",
      "Client 11 Private Data Validation Accuracy: 75.00%\n",
      "\n",
      "--- Client 12 Local Training with T=0.6155, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.7090\n",
      "Epoch [2/5], Loss: 0.1321\n",
      "Epoch [3/5], Loss: 0.0844\n",
      "Epoch [4/5], Loss: 0.0556\n",
      "Epoch [5/5], Loss: 0.0458\n",
      "Client 12 Private Data Validation Accuracy: 82.29%\n",
      "\n",
      "--- Client 13 Local Training with T=0.6881, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5897\n",
      "Epoch [2/5], Loss: 0.2493\n",
      "Epoch [3/5], Loss: 0.1691\n",
      "Epoch [4/5], Loss: 0.0879\n",
      "Epoch [5/5], Loss: 0.0872\n",
      "Client 13 Private Data Validation Accuracy: 62.50%\n",
      "\n",
      "--- Client 14 Local Training with T=0.5544, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.0461\n",
      "Epoch [2/5], Loss: 0.3506\n",
      "Epoch [3/5], Loss: 0.2014\n",
      "Epoch [4/5], Loss: 0.1026\n",
      "Epoch [5/5], Loss: 0.0632\n",
      "Client 14 Private Data Validation Accuracy: 62.50%\n",
      "\n",
      "--- Client 15 Local Training with T=0.5476, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.7384\n",
      "Epoch [2/5], Loss: 0.1859\n",
      "Epoch [3/5], Loss: 0.0638\n",
      "Epoch [4/5], Loss: 0.0394\n",
      "Epoch [5/5], Loss: 0.0250\n",
      "Client 15 Private Data Validation Accuracy: 76.04%\n",
      "Client 15 - decrease-temp to 0.520219\n",
      "\n",
      "--- Client 16 Local Training with T=0.5709, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6125\n",
      "Epoch [2/5], Loss: 0.1665\n",
      "Epoch [3/5], Loss: 0.0780\n",
      "Epoch [4/5], Loss: 0.0481\n",
      "Epoch [5/5], Loss: 0.0341\n",
      "Client 16 Private Data Validation Accuracy: 68.75%\n",
      "\n",
      "--- Client 17 Local Training with T=0.5723, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.8175\n",
      "Epoch [2/5], Loss: 0.3092\n",
      "Epoch [3/5], Loss: 0.1452\n",
      "Epoch [4/5], Loss: 0.1122\n",
      "Epoch [5/5], Loss: 0.0590\n",
      "Client 17 Private Data Validation Accuracy: 64.84%\n",
      "\n",
      "--- Client 18 Local Training with T=0.5406, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6091\n",
      "Epoch [2/5], Loss: 0.1695\n",
      "Epoch [3/5], Loss: 0.1355\n",
      "Epoch [4/5], Loss: 0.0617\n",
      "Epoch [5/5], Loss: 0.0365\n",
      "Client 18 Private Data Validation Accuracy: 75.78%\n",
      "Client 18 - decrease-temp to 0.513580\n",
      "\n",
      "--- Client 19 Local Training with T=0.5260, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5940\n",
      "Epoch [2/5], Loss: 0.1899\n",
      "Epoch [3/5], Loss: 0.0901\n",
      "Epoch [4/5], Loss: 0.0316\n",
      "Epoch [5/5], Loss: 0.0388\n",
      "Client 19 Private Data Validation Accuracy: 65.10%\n",
      "Client 19 - decrease-temp to 0.499703\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 57.74%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 1.0\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.8110, Initial temp: 0.4444\n",
      "Client 1 - Het Score: 0.9049, Initial temp: 0.4046\n",
      "Client 2 - Het Score: 0.8625, Initial temp: 0.4221\n",
      "Client 3 - Het Score: 0.8298, Initial temp: 0.4362\n",
      "Client 4 - Het Score: 0.7894, Initial temp: 0.4541\n",
      "Client 5 - Het Score: 0.6072, Initial temp: 0.5449\n",
      "Client 6 - Het Score: 0.7654, Initial temp: 0.4651\n",
      "Client 7 - Het Score: 0.4441, Initial temp: 0.6414\n",
      "Client 8 - Het Score: 0.8104, Initial temp: 0.4447\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.3679\n",
      "Client 10 - Het Score: 0.9858, Initial temp: 0.3731\n",
      "Client 11 - Het Score: 0.8155, Initial temp: 0.4424\n",
      "Client 12 - Het Score: 0.6630, Initial temp: 0.5153\n",
      "Client 13 - Het Score: 0.5425, Initial temp: 0.5813\n",
      "Client 14 - Het Score: 0.8720, Initial temp: 0.4181\n",
      "Client 15 - Het Score: 0.7941, Initial temp: 0.4520\n",
      "Client 16 - Het Score: 0.6080, Initial temp: 0.5444\n",
      "Client 17 - Het Score: 0.9110, Initial temp: 0.4021\n",
      "Client 18 - Het Score: 0.8198, Initial temp: 0.4405\n",
      "Client 19 - Het Score: 0.8745, Initial temp: 0.4170\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.4011, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3581\n",
      "Epoch [2/5], Loss: 0.1786\n",
      "Epoch [3/5], Loss: 0.0668\n",
      "Epoch [4/5], Loss: 0.0590\n",
      "Epoch [5/5], Loss: 0.0538\n",
      "Client 0 Private Data Validation Accuracy: 75.00%\n",
      "\n",
      "--- Client 1 Local Training with T=0.3295, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4683\n",
      "Epoch [2/5], Loss: 0.0810\n",
      "Epoch [3/5], Loss: 0.0468\n",
      "Epoch [4/5], Loss: 0.0638\n",
      "Epoch [5/5], Loss: 0.0164\n",
      "Client 1 Private Data Validation Accuracy: 81.25%\n",
      "Client 1 - decrease-temp to 0.313067\n",
      "\n",
      "--- Client 2 Local Training with T=0.3266, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4078\n",
      "Epoch [2/5], Loss: 0.1648\n",
      "Epoch [3/5], Loss: 0.0653\n",
      "Epoch [4/5], Loss: 0.0575\n",
      "Epoch [5/5], Loss: 0.0321\n",
      "Client 2 Private Data Validation Accuracy: 68.23%\n",
      "\n",
      "--- Client 3 Local Training with T=0.3553, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4674\n",
      "Epoch [2/5], Loss: 0.1534\n",
      "Epoch [3/5], Loss: 0.0704\n",
      "Epoch [4/5], Loss: 0.0555\n",
      "Epoch [5/5], Loss: 0.0234\n",
      "Client 3 Private Data Validation Accuracy: 69.79%\n",
      "\n",
      "--- Client 4 Local Training with T=0.3338, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4676\n",
      "Epoch [2/5], Loss: 0.2492\n",
      "Epoch [3/5], Loss: 0.1763\n",
      "Epoch [4/5], Loss: 0.0642\n",
      "Epoch [5/5], Loss: 0.0473\n",
      "Client 4 Private Data Validation Accuracy: 64.06%\n",
      "\n",
      "--- Client 5 Local Training with T=0.4918, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4634\n",
      "Epoch [2/5], Loss: 0.1407\n",
      "Epoch [3/5], Loss: 0.0740\n",
      "Epoch [4/5], Loss: 0.0496\n",
      "Epoch [5/5], Loss: 0.0354\n",
      "Client 5 Private Data Validation Accuracy: 68.75%\n",
      "\n",
      "--- Client 6 Local Training with T=0.3988, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3472\n",
      "Epoch [2/5], Loss: 0.0808\n",
      "Epoch [3/5], Loss: 0.0343\n",
      "Epoch [4/5], Loss: 0.0176\n",
      "Epoch [5/5], Loss: 0.0211\n",
      "Client 6 Private Data Validation Accuracy: 72.92%\n",
      "\n",
      "--- Client 7 Local Training with T=0.5788, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3121\n",
      "Epoch [2/5], Loss: 0.1360\n",
      "Epoch [3/5], Loss: 0.0935\n",
      "Epoch [4/5], Loss: 0.0701\n",
      "Epoch [5/5], Loss: 0.0582\n",
      "Client 7 Private Data Validation Accuracy: 60.55%\n",
      "\n",
      "--- Client 8 Local Training with T=0.3441, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.7977\n",
      "Epoch [2/5], Loss: 0.1854\n",
      "Epoch [3/5], Loss: 0.0760\n",
      "Epoch [4/5], Loss: 0.0342\n",
      "Epoch [5/5], Loss: 0.0303\n",
      "Client 8 Private Data Validation Accuracy: 67.97%\n",
      "\n",
      "--- Client 9 Local Training with T=0.3320, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2779\n",
      "Epoch [2/5], Loss: 0.0900\n",
      "Epoch [3/5], Loss: 0.0410\n",
      "Epoch [4/5], Loss: 0.0266\n",
      "Epoch [5/5], Loss: 0.0161\n",
      "Client 9 Private Data Validation Accuracy: 75.78%\n",
      "\n",
      "--- Client 10 Local Training with T=0.2887, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.7527\n",
      "Epoch [2/5], Loss: 0.2645\n",
      "Epoch [3/5], Loss: 0.1086\n",
      "Epoch [4/5], Loss: 0.0395\n",
      "Epoch [5/5], Loss: 0.0542\n",
      "Client 10 Private Data Validation Accuracy: 70.70%\n",
      "\n",
      "--- Client 11 Local Training with T=0.3603, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.6304\n",
      "Epoch [2/5], Loss: 0.2855\n",
      "Epoch [3/5], Loss: 0.0573\n",
      "Epoch [4/5], Loss: 0.0249\n",
      "Epoch [5/5], Loss: 0.0316\n",
      "Client 11 Private Data Validation Accuracy: 71.88%\n",
      "\n",
      "--- Client 12 Local Training with T=0.4895, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6284\n",
      "Epoch [2/5], Loss: 0.0780\n",
      "Epoch [3/5], Loss: 0.0651\n",
      "Epoch [4/5], Loss: 0.0521\n",
      "Epoch [5/5], Loss: 0.0310\n",
      "Client 12 Private Data Validation Accuracy: 80.73%\n",
      "\n",
      "--- Client 13 Local Training with T=0.4984, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5636\n",
      "Epoch [2/5], Loss: 0.2062\n",
      "Epoch [3/5], Loss: 0.1443\n",
      "Epoch [4/5], Loss: 0.0840\n",
      "Epoch [5/5], Loss: 0.0725\n",
      "Client 13 Private Data Validation Accuracy: 64.06%\n",
      "Client 13 - decrease-temp to 0.473451\n",
      "\n",
      "--- Client 14 Local Training with T=0.3406, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.1183\n",
      "Epoch [2/5], Loss: 0.3230\n",
      "Epoch [3/5], Loss: 0.1459\n",
      "Epoch [4/5], Loss: 0.0707\n",
      "Epoch [5/5], Loss: 0.0666\n",
      "Client 14 Private Data Validation Accuracy: 60.94%\n",
      "\n",
      "--- Client 15 Local Training with T=0.3682, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.7914\n",
      "Epoch [2/5], Loss: 0.1649\n",
      "Epoch [3/5], Loss: 0.0588\n",
      "Epoch [4/5], Loss: 0.0443\n",
      "Epoch [5/5], Loss: 0.0208\n",
      "Client 15 Private Data Validation Accuracy: 79.17%\n",
      "\n",
      "--- Client 16 Local Training with T=0.5172, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4942\n",
      "Epoch [2/5], Loss: 0.1286\n",
      "Epoch [3/5], Loss: 0.0669\n",
      "Epoch [4/5], Loss: 0.0504\n",
      "Epoch [5/5], Loss: 0.0345\n",
      "Client 16 Private Data Validation Accuracy: 68.75%\n",
      "\n",
      "--- Client 17 Local Training with T=0.3820, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6971\n",
      "Epoch [2/5], Loss: 0.3280\n",
      "Epoch [3/5], Loss: 0.1752\n",
      "Epoch [4/5], Loss: 0.0778\n",
      "Epoch [5/5], Loss: 0.0519\n",
      "Client 17 Private Data Validation Accuracy: 61.72%\n",
      "\n",
      "--- Client 18 Local Training with T=0.3588, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5339\n",
      "Epoch [2/5], Loss: 0.2056\n",
      "Epoch [3/5], Loss: 0.1112\n",
      "Epoch [4/5], Loss: 0.0618\n",
      "Epoch [5/5], Loss: 0.0367\n",
      "Client 18 Private Data Validation Accuracy: 78.12%\n",
      "\n",
      "--- Client 19 Local Training with T=0.3066, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6061\n",
      "Epoch [2/5], Loss: 0.1608\n",
      "Epoch [3/5], Loss: 0.1142\n",
      "Epoch [4/5], Loss: 0.0325\n",
      "Epoch [5/5], Loss: 0.0314\n",
      "Client 19 Private Data Validation Accuracy: 66.15%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 58.47%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 1.5\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.8110, Initial temp: 0.2963\n",
      "Client 1 - Het Score: 0.9049, Initial temp: 0.2574\n",
      "Client 2 - Het Score: 0.8625, Initial temp: 0.2743\n",
      "Client 3 - Het Score: 0.8298, Initial temp: 0.2880\n",
      "Client 4 - Het Score: 0.7894, Initial temp: 0.3060\n",
      "Client 5 - Het Score: 0.6072, Initial temp: 0.4022\n",
      "Client 6 - Het Score: 0.7654, Initial temp: 0.3172\n",
      "Client 7 - Het Score: 0.4441, Initial temp: 0.5137\n",
      "Client 8 - Het Score: 0.8104, Initial temp: 0.2965\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.2231\n",
      "Client 10 - Het Score: 0.9858, Initial temp: 0.2279\n",
      "Client 11 - Het Score: 0.8155, Initial temp: 0.2943\n",
      "Client 12 - Het Score: 0.6630, Initial temp: 0.3699\n",
      "Client 13 - Het Score: 0.5425, Initial temp: 0.4432\n",
      "Client 14 - Het Score: 0.8720, Initial temp: 0.2704\n",
      "Client 15 - Het Score: 0.7941, Initial temp: 0.3039\n",
      "Client 16 - Het Score: 0.6080, Initial temp: 0.4017\n",
      "Client 17 - Het Score: 0.9110, Initial temp: 0.2550\n",
      "Client 18 - Het Score: 0.8198, Initial temp: 0.2924\n",
      "Client 19 - Het Score: 0.8745, Initial temp: 0.2693\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_36/2023819718.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrun_experiment\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mplot_model_comparison\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_36/285767418.py\u001b[0m in \u001b[0;36mrun_experiment\u001b[0;34m()\u001b[0m\n\u001b[1;32m     12\u001b[0m         \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{'#'*50}\\n\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m         \u001b[0mserver_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclient_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfinal_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrun_fedchill\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscaling_factor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m         \u001b[0mavg_client_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclient_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtolist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_36/985168388.py\u001b[0m in \u001b[0;36mrun_fedchill\u001b[0;34m(scaling_factor, model_class)\u001b[0m\n\u001b[1;32m     67\u001b[0m                 \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Client {client_idx} Private Data Validation Accuracy: {local_acc:.2f}%\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     68\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 69\u001b[0;31m             \u001b[0mclient_testset_acc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mevaluate_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclient_models\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mclient_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     70\u001b[0m             \u001b[0mround_client_acc_list\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mclient_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlocal_acc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     71\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_36/3183112492.py\u001b[0m in \u001b[0;36mevaluate_model\u001b[0;34m(model, data_loader)\u001b[0m\n\u001b[1;32m    129\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdata_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    130\u001b[0m             \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m             \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    132\u001b[0m             \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpredicted\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m             \u001b[0mtotal\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mtargets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1737\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1738\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1739\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1740\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1741\u001b[0m     \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1748\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1749\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1750\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1751\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1752\u001b[0m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_36/1690506672.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     73\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     74\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 75\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn_fc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     76\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     77\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   1913\u001b[0m     \u001b[0;31m# on `torch.nn.Module` and all its subclasses is largely disabled as a result. See:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1914\u001b[0m     \u001b[0;31m# https://github.com/pytorch/pytorch/pull/115074\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1915\u001b[0;31m     \u001b[0;32mdef\u001b[0m \u001b[0m__getattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Module\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1916\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;34m\"_parameters\"\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dict__\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1917\u001b[0m             \u001b[0m_parameters\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dict__\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"_parameters\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Run the experiment and keep its return value in the notebook-global `results`.\n",
    "# Per the traceback saved with this cell, run_experiment() calls\n",
    "# run_fedchill(scaling_factor=sf), so this call performs the federated training itself.\n",
    "# NOTE(review): the output stored with this cell ends in a KeyboardInterrupt raised while\n",
    "# run_experiment() was still executing, i.e. that run stopped before `results` was assigned\n",
    "# and before the plot below was drawn - re-run to completion before relying on it.\n",
    "results = run_experiment()\n",
    "# Presumably renders a comparison figure from `results`; the helper is defined in another\n",
    "# cell - TODO confirm the structure it expects.\n",
    "plot_model_comparison(results)\n",
    "\n",
    "# Plain-text dump of the raw results object, for reference alongside the figure.\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-23T04:16:06.426959Z",
     "iopub.status.busy": "2025-09-23T04:16:06.426462Z",
     "iopub.status.idle": "2025-09-23T06:51:39.112726Z",
     "shell.execute_reply": "2025-09-23T06:51:39.111547Z",
     "shell.execute_reply.started": "2025-09-23T04:16:06.426929Z"
    },
    "trusted": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "================================================================================\n",
      "================================================================================\n",
      "\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 1.5\n",
      "##################################################\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 170M/170M [00:08<00:00, 21.1MB/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.8110, Initial temp: 0.2963\n",
      "Client 1 - Het Score: 0.9049, Initial temp: 0.2574\n",
      "Client 2 - Het Score: 0.8625, Initial temp: 0.2743\n",
      "Client 3 - Het Score: 0.8298, Initial temp: 0.2880\n",
      "Client 4 - Het Score: 0.7894, Initial temp: 0.3060\n",
      "Client 5 - Het Score: 0.6072, Initial temp: 0.4022\n",
      "Client 6 - Het Score: 0.7654, Initial temp: 0.3172\n",
      "Client 7 - Het Score: 0.4441, Initial temp: 0.5137\n",
      "Client 8 - Het Score: 0.8104, Initial temp: 0.2965\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.2231\n",
      "Client 10 - Het Score: 0.9858, Initial temp: 0.2279\n",
      "Client 11 - Het Score: 0.8155, Initial temp: 0.2943\n",
      "Client 12 - Het Score: 0.6630, Initial temp: 0.3699\n",
      "Client 13 - Het Score: 0.5425, Initial temp: 0.4432\n",
      "Client 14 - Het Score: 0.8720, Initial temp: 0.2704\n",
      "Client 15 - Het Score: 0.7941, Initial temp: 0.3039\n",
      "Client 16 - Het Score: 0.6080, Initial temp: 0.4017\n",
      "Client 17 - Het Score: 0.9110, Initial temp: 0.2550\n",
      "Client 18 - Het Score: 0.8198, Initial temp: 0.2924\n",
      "Client 19 - Het Score: 0.8745, Initial temp: 0.2693\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.2540, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3422\n",
      "Epoch [2/5], Loss: 0.1217\n",
      "Epoch [3/5], Loss: 0.0731\n",
      "Epoch [4/5], Loss: 0.0631\n",
      "Epoch [5/5], Loss: 0.0500\n",
      "Client 0 Private Data Validation Accuracy: 73.44%\n",
      "Client 0 - decrease-temp to 0.241314\n",
      "\n",
      "--- Client 1 Local Training with T=0.2323, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4282\n",
      "Epoch [2/5], Loss: 0.0772\n",
      "Epoch [3/5], Loss: 0.0372\n",
      "Epoch [4/5], Loss: 0.0437\n",
      "Epoch [5/5], Loss: 0.0224\n",
      "Client 1 Private Data Validation Accuracy: 82.81%\n",
      "\n",
      "--- Client 2 Local Training with T=0.2234, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3114\n",
      "Epoch [2/5], Loss: 0.1067\n",
      "Epoch [3/5], Loss: 0.0822\n",
      "Epoch [4/5], Loss: 0.0695\n",
      "Epoch [5/5], Loss: 0.0208\n",
      "Client 2 Private Data Validation Accuracy: 72.40%\n",
      "\n",
      "--- Client 3 Local Training with T=0.2470, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4431\n",
      "Epoch [2/5], Loss: 0.1932\n",
      "Epoch [3/5], Loss: 0.1336\n",
      "Epoch [4/5], Loss: 0.0470\n",
      "Epoch [5/5], Loss: 0.0241\n",
      "Client 3 Private Data Validation Accuracy: 72.40%\n",
      "\n",
      "--- Client 4 Local Training with T=0.2762, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4202\n",
      "Epoch [2/5], Loss: 0.2390\n",
      "Epoch [3/5], Loss: 0.1581\n",
      "Epoch [4/5], Loss: 0.0545\n",
      "Epoch [5/5], Loss: 0.0362\n",
      "Client 4 Private Data Validation Accuracy: 70.31%\n",
      "\n",
      "--- Client 5 Local Training with T=0.3448, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4544\n",
      "Epoch [2/5], Loss: 0.1353\n",
      "Epoch [3/5], Loss: 0.0772\n",
      "Epoch [4/5], Loss: 0.0547\n",
      "Epoch [5/5], Loss: 0.0376\n",
      "Client 5 Private Data Validation Accuracy: 67.71%\n",
      "Client 5 - decrease-temp to 0.327606\n",
      "\n",
      "--- Client 6 Local Training with T=0.2455, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3464\n",
      "Epoch [2/5], Loss: 0.0747\n",
      "Epoch [3/5], Loss: 0.0461\n",
      "Epoch [4/5], Loss: 0.0109\n",
      "Epoch [5/5], Loss: 0.0194\n",
      "Client 6 Private Data Validation Accuracy: 73.44%\n",
      "\n",
      "--- Client 7 Local Training with T=0.4404, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3116\n",
      "Epoch [2/5], Loss: 0.1401\n",
      "Epoch [3/5], Loss: 0.0940\n",
      "Epoch [4/5], Loss: 0.0662\n",
      "Epoch [5/5], Loss: 0.0497\n",
      "Client 7 Private Data Validation Accuracy: 59.77%\n",
      "\n",
      "--- Client 8 Local Training with T=0.2415, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.8148\n",
      "Epoch [2/5], Loss: 0.1409\n",
      "Epoch [3/5], Loss: 0.1645\n",
      "Epoch [4/5], Loss: 0.0463\n",
      "Epoch [5/5], Loss: 0.0556\n",
      "Client 8 Private Data Validation Accuracy: 63.28%\n",
      "Client 8 - decrease-temp to 0.229455\n",
      "\n",
      "--- Client 9 Local Training with T=0.1817, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3202\n",
      "Epoch [2/5], Loss: 0.2243\n",
      "Epoch [3/5], Loss: 0.1660\n",
      "Epoch [4/5], Loss: 0.1343\n",
      "Epoch [5/5], Loss: 0.0555\n",
      "Client 9 Private Data Validation Accuracy: 71.88%\n",
      "\n",
      "--- Client 10 Local Training with T=0.1954, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5447\n",
      "Epoch [2/5], Loss: 0.2155\n",
      "Epoch [3/5], Loss: 0.1448\n",
      "Epoch [4/5], Loss: 0.0732\n",
      "Epoch [5/5], Loss: 0.0738\n",
      "Client 10 Private Data Validation Accuracy: 73.05%\n",
      "\n",
      "--- Client 11 Local Training with T=0.2277, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.6712\n",
      "Epoch [2/5], Loss: 0.2625\n",
      "Epoch [3/5], Loss: 0.0733\n",
      "Epoch [4/5], Loss: 0.0291\n",
      "Epoch [5/5], Loss: 0.0187\n",
      "Client 11 Private Data Validation Accuracy: 71.88%\n",
      "Client 11 - decrease-temp to 0.216303\n",
      "\n",
      "--- Client 12 Local Training with T=0.3514, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5515\n",
      "Epoch [2/5], Loss: 0.0792\n",
      "Epoch [3/5], Loss: 0.0468\n",
      "Epoch [4/5], Loss: 0.0458\n",
      "Epoch [5/5], Loss: 0.0262\n",
      "Client 12 Private Data Validation Accuracy: 80.73%\n",
      "\n",
      "--- Client 13 Local Training with T=0.3258, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4899\n",
      "Epoch [2/5], Loss: 0.2247\n",
      "Epoch [3/5], Loss: 0.1273\n",
      "Epoch [4/5], Loss: 0.0557\n",
      "Epoch [5/5], Loss: 0.0537\n",
      "Client 13 Private Data Validation Accuracy: 64.06%\n",
      "\n",
      "--- Client 14 Local Training with T=0.2202, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.9878\n",
      "Epoch [2/5], Loss: 0.2361\n",
      "Epoch [3/5], Loss: 0.2122\n",
      "Epoch [4/5], Loss: 0.0861\n",
      "Epoch [5/5], Loss: 0.0371\n",
      "Client 14 Private Data Validation Accuracy: 57.81%\n",
      "\n",
      "--- Client 15 Local Training with T=0.2605, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6695\n",
      "Epoch [2/5], Loss: 0.1530\n",
      "Epoch [3/5], Loss: 0.0741\n",
      "Epoch [4/5], Loss: 0.0447\n",
      "Epoch [5/5], Loss: 0.0333\n",
      "Client 15 Private Data Validation Accuracy: 76.56%\n",
      "\n",
      "--- Client 16 Local Training with T=0.3626, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4381\n",
      "Epoch [2/5], Loss: 0.1049\n",
      "Epoch [3/5], Loss: 0.0572\n",
      "Epoch [4/5], Loss: 0.0458\n",
      "Epoch [5/5], Loss: 0.0330\n",
      "Client 16 Private Data Validation Accuracy: 68.23%\n",
      "Client 16 - decrease-temp to 0.344425\n",
      "\n",
      "--- Client 17 Local Training with T=0.2186, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6615\n",
      "Epoch [2/5], Loss: 0.4268\n",
      "Epoch [3/5], Loss: 0.2097\n",
      "Epoch [4/5], Loss: 0.1143\n",
      "Epoch [5/5], Loss: 0.0383\n",
      "Client 17 Private Data Validation Accuracy: 53.91%\n",
      "\n",
      "--- Client 18 Local Training with T=0.2507, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3775\n",
      "Epoch [2/5], Loss: 0.1346\n",
      "Epoch [3/5], Loss: 0.1099\n",
      "Epoch [4/5], Loss: 0.0624\n",
      "Epoch [5/5], Loss: 0.0326\n",
      "Client 18 Private Data Validation Accuracy: 70.31%\n",
      "Client 18 - decrease-temp to 0.238157\n",
      "\n",
      "--- Client 19 Local Training with T=0.2194, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4152\n",
      "Epoch [2/5], Loss: 0.1284\n",
      "Epoch [3/5], Loss: 0.1091\n",
      "Epoch [4/5], Loss: 0.0467\n",
      "Epoch [5/5], Loss: 0.0438\n",
      "Client 19 Private Data Validation Accuracy: 69.79%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 58.67%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 2.0\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.8110, Initial temp: 0.1975\n",
      "Client 1 - Het Score: 0.9049, Initial temp: 0.1637\n",
      "Client 2 - Het Score: 0.8625, Initial temp: 0.1782\n",
      "Client 3 - Het Score: 0.8298, Initial temp: 0.1902\n",
      "Client 4 - Het Score: 0.7894, Initial temp: 0.2062\n",
      "Client 5 - Het Score: 0.6072, Initial temp: 0.2969\n",
      "Client 6 - Het Score: 0.7654, Initial temp: 0.2163\n",
      "Client 7 - Het Score: 0.4441, Initial temp: 0.4114\n",
      "Client 8 - Het Score: 0.8104, Initial temp: 0.1977\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.1353\n",
      "Client 10 - Het Score: 0.9858, Initial temp: 0.1392\n",
      "Client 11 - Het Score: 0.8155, Initial temp: 0.1957\n",
      "Client 12 - Het Score: 0.6630, Initial temp: 0.2655\n",
      "Client 13 - Het Score: 0.5425, Initial temp: 0.3379\n",
      "Client 14 - Het Score: 0.8720, Initial temp: 0.1748\n",
      "Client 15 - Het Score: 0.7941, Initial temp: 0.2043\n",
      "Client 16 - Het Score: 0.6080, Initial temp: 0.2964\n",
      "Client 17 - Het Score: 0.9110, Initial temp: 0.1617\n",
      "Client 18 - Het Score: 0.8198, Initial temp: 0.1941\n",
      "Client 19 - Het Score: 0.8745, Initial temp: 0.1739\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.1783, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3663\n",
      "Epoch [2/5], Loss: 0.2093\n",
      "Epoch [3/5], Loss: 0.0624\n",
      "Epoch [4/5], Loss: 0.0620\n",
      "Epoch [5/5], Loss: 0.0389\n",
      "Client 0 Private Data Validation Accuracy: 72.92%\n",
      "\n",
      "--- Client 1 Local Training with T=0.1555, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4287\n",
      "Epoch [2/5], Loss: 0.0868\n",
      "Epoch [3/5], Loss: 0.0330\n",
      "Epoch [4/5], Loss: 0.0222\n",
      "Epoch [5/5], Loss: 0.0223\n",
      "Client 1 Private Data Validation Accuracy: 84.38%\n",
      "\n",
      "--- Client 2 Local Training with T=0.1451, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3193\n",
      "Epoch [2/5], Loss: 0.1801\n",
      "Epoch [3/5], Loss: 0.1232\n",
      "Epoch [4/5], Loss: 0.0974\n",
      "Epoch [5/5], Loss: 0.0212\n",
      "Client 2 Private Data Validation Accuracy: 70.31%\n",
      "\n",
      "--- Client 3 Local Training with T=0.1472, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5373\n",
      "Epoch [2/5], Loss: 0.1056\n",
      "Epoch [3/5], Loss: 0.1047\n",
      "Epoch [4/5], Loss: 0.0506\n",
      "Epoch [5/5], Loss: 0.0336\n",
      "Client 3 Private Data Validation Accuracy: 69.27%\n",
      "\n",
      "--- Client 4 Local Training with T=0.1861, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4384\n",
      "Epoch [2/5], Loss: 0.1861\n",
      "Epoch [3/5], Loss: 0.1315\n",
      "Epoch [4/5], Loss: 0.0584\n",
      "Epoch [5/5], Loss: 0.0343\n",
      "Client 4 Private Data Validation Accuracy: 71.88%\n",
      "\n",
      "--- Client 5 Local Training with T=0.2546, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4050\n",
      "Epoch [2/5], Loss: 0.1187\n",
      "Epoch [3/5], Loss: 0.0755\n",
      "Epoch [4/5], Loss: 0.0442\n",
      "Epoch [5/5], Loss: 0.0229\n",
      "Client 5 Private Data Validation Accuracy: 71.88%\n",
      "\n",
      "--- Client 6 Local Training with T=0.1762, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3635\n",
      "Epoch [2/5], Loss: 0.0940\n",
      "Epoch [3/5], Loss: 0.0460\n",
      "Epoch [4/5], Loss: 0.0243\n",
      "Epoch [5/5], Loss: 0.0281\n",
      "Client 6 Private Data Validation Accuracy: 76.04%\n",
      "\n",
      "--- Client 7 Local Training with T=0.3908, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3474\n",
      "Epoch [2/5], Loss: 0.1737\n",
      "Epoch [3/5], Loss: 0.1086\n",
      "Epoch [4/5], Loss: 0.0807\n",
      "Epoch [5/5], Loss: 0.0639\n",
      "Client 7 Private Data Validation Accuracy: 66.02%\n",
      "\n",
      "--- Client 8 Local Training with T=0.1530, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.0811\n",
      "Epoch [2/5], Loss: 0.2226\n",
      "Epoch [3/5], Loss: 0.1376\n",
      "Epoch [4/5], Loss: 0.0367\n",
      "Epoch [5/5], Loss: 0.0792\n",
      "Client 8 Private Data Validation Accuracy: 61.72%\n",
      "\n",
      "--- Client 9 Local Training with T=0.1047, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2774\n",
      "Epoch [2/5], Loss: 0.2954\n",
      "Epoch [3/5], Loss: 0.3426\n",
      "Epoch [4/5], Loss: 0.1351\n",
      "Epoch [5/5], Loss: 0.1898\n",
      "Client 9 Private Data Validation Accuracy: 70.31%\n",
      "\n",
      "--- Client 10 Local Training with T=0.1134, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.2980\n",
      "Epoch [2/5], Loss: 0.3001\n",
      "Epoch [3/5], Loss: 0.1473\n",
      "Epoch [4/5], Loss: 0.1064\n",
      "Epoch [5/5], Loss: 0.1268\n",
      "Client 10 Private Data Validation Accuracy: 71.09%\n",
      "\n",
      "--- Client 11 Local Training with T=0.1766, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.2987\n",
      "Epoch [2/5], Loss: 0.2498\n",
      "Epoch [3/5], Loss: 0.0616\n",
      "Epoch [4/5], Loss: 0.0311\n",
      "Epoch [5/5], Loss: 0.0193\n",
      "Client 11 Private Data Validation Accuracy: 70.31%\n",
      "\n",
      "--- Client 12 Local Training with T=0.2397, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5507\n",
      "Epoch [2/5], Loss: 0.0747\n",
      "Epoch [3/5], Loss: 0.0643\n",
      "Epoch [4/5], Loss: 0.0238\n",
      "Epoch [5/5], Loss: 0.0401\n",
      "Client 12 Private Data Validation Accuracy: 80.21%\n",
      "\n",
      "--- Client 13 Local Training with T=0.2614, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4930\n",
      "Epoch [2/5], Loss: 0.2187\n",
      "Epoch [3/5], Loss: 0.1235\n",
      "Epoch [4/5], Loss: 0.0479\n",
      "Epoch [5/5], Loss: 0.0508\n",
      "Client 13 Private Data Validation Accuracy: 67.19%\n",
      "\n",
      "--- Client 14 Local Training with T=0.1499, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.1126\n",
      "Epoch [2/5], Loss: 0.4370\n",
      "Epoch [3/5], Loss: 0.2293\n",
      "Epoch [4/5], Loss: 0.0992\n",
      "Epoch [5/5], Loss: 0.0539\n",
      "Client 14 Private Data Validation Accuracy: 65.62%\n",
      "\n",
      "--- Client 15 Local Training with T=0.1752, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6194\n",
      "Epoch [2/5], Loss: 0.2222\n",
      "Epoch [3/5], Loss: 0.0629\n",
      "Epoch [4/5], Loss: 0.0379\n",
      "Epoch [5/5], Loss: 0.0279\n",
      "Client 15 Private Data Validation Accuracy: 78.12%\n",
      "\n",
      "--- Client 16 Local Training with T=0.2675, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4559\n",
      "Epoch [2/5], Loss: 0.1208\n",
      "Epoch [3/5], Loss: 0.0579\n",
      "Epoch [4/5], Loss: 0.0373\n",
      "Epoch [5/5], Loss: 0.0341\n",
      "Client 16 Private Data Validation Accuracy: 69.79%\n",
      "\n",
      "--- Client 17 Local Training with T=0.1387, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.1703\n",
      "Epoch [2/5], Loss: 0.5493\n",
      "Epoch [3/5], Loss: 0.2962\n",
      "Epoch [4/5], Loss: 0.1697\n",
      "Epoch [5/5], Loss: 0.1417\n",
      "Client 17 Private Data Validation Accuracy: 51.56%\n",
      "\n",
      "--- Client 18 Local Training with T=0.1664, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4268\n",
      "Epoch [2/5], Loss: 0.1602\n",
      "Epoch [3/5], Loss: 0.1368\n",
      "Epoch [4/5], Loss: 0.0627\n",
      "Epoch [5/5], Loss: 0.0301\n",
      "Client 18 Private Data Validation Accuracy: 73.44%\n",
      "\n",
      "--- Client 19 Local Training with T=0.1491, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5705\n",
      "Epoch [2/5], Loss: 0.2615\n",
      "Epoch [3/5], Loss: 0.1194\n",
      "Epoch [4/5], Loss: 0.0290\n",
      "Epoch [5/5], Loss: 0.0103\n",
      "Client 19 Private Data Validation Accuracy: 70.31%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 59.38%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 2.5\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.8110, Initial temp: 0.1317\n",
      "Client 1 - Het Score: 0.9049, Initial temp: 0.1041\n",
      "Client 2 - Het Score: 0.8625, Initial temp: 0.1158\n",
      "Client 3 - Het Score: 0.8298, Initial temp: 0.1256\n",
      "Client 4 - Het Score: 0.7894, Initial temp: 0.1390\n",
      "Client 5 - Het Score: 0.6072, Initial temp: 0.2192\n",
      "Client 6 - Het Score: 0.7654, Initial temp: 0.1475\n",
      "Client 7 - Het Score: 0.4441, Initial temp: 0.3294\n",
      "Client 8 - Het Score: 0.8104, Initial temp: 0.1319\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.0821\n",
      "Client 10 - Het Score: 0.9858, Initial temp: 0.0850\n",
      "Client 11 - Het Score: 0.8155, Initial temp: 0.1302\n",
      "Client 12 - Het Score: 0.6630, Initial temp: 0.1906\n",
      "Client 13 - Het Score: 0.5425, Initial temp: 0.2576\n",
      "Client 14 - Het Score: 0.8720, Initial temp: 0.1131\n",
      "Client 15 - Het Score: 0.7941, Initial temp: 0.1374\n",
      "Client 16 - Het Score: 0.6080, Initial temp: 0.2187\n",
      "Client 17 - Het Score: 0.9110, Initial temp: 0.1026\n",
      "Client 18 - Het Score: 0.8198, Initial temp: 0.1288\n",
      "Client 19 - Het Score: 0.8745, Initial temp: 0.1123\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.1251, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4780\n",
      "Epoch [2/5], Loss: 0.2056\n",
      "Epoch [3/5], Loss: 0.1002\n",
      "Epoch [4/5], Loss: 0.0766\n",
      "Epoch [5/5], Loss: 0.0650\n",
      "Client 0 Private Data Validation Accuracy: 71.35%\n",
      "\n",
      "--- Client 1 Local Training with T=0.0989, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6677\n",
      "Epoch [2/5], Loss: 0.1975\n",
      "Epoch [3/5], Loss: 0.0788\n",
      "Epoch [4/5], Loss: 0.0526\n",
      "Epoch [5/5], Loss: 0.0928\n",
      "Client 1 Private Data Validation Accuracy: 82.81%\n",
      "Client 1 - decrease-temp to 0.093971\n",
      "\n",
      "--- Client 2 Local Training with T=0.1045, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.9800\n",
      "Epoch [2/5], Loss: 0.3452\n",
      "Epoch [3/5], Loss: 0.2172\n",
      "Epoch [4/5], Loss: 0.0904\n",
      "Epoch [5/5], Loss: 0.0452\n",
      "Client 2 Private Data Validation Accuracy: 68.23%\n",
      "\n",
      "--- Client 3 Local Training with T=0.1077, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.9394\n",
      "Epoch [2/5], Loss: 0.1662\n",
      "Epoch [3/5], Loss: 0.2305\n",
      "Epoch [4/5], Loss: 0.1480\n",
      "Epoch [5/5], Loss: 0.0402\n",
      "Client 3 Private Data Validation Accuracy: 69.79%\n",
      "\n",
      "--- Client 4 Local Training with T=0.1132, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.7272\n",
      "Epoch [2/5], Loss: 0.5227\n",
      "Epoch [3/5], Loss: 0.4224\n",
      "Epoch [4/5], Loss: 0.1022\n",
      "Epoch [5/5], Loss: 0.0768\n",
      "Client 4 Private Data Validation Accuracy: 68.75%\n",
      "\n",
      "--- Client 5 Local Training with T=0.1978, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4815\n",
      "Epoch [2/5], Loss: 0.1616\n",
      "Epoch [3/5], Loss: 0.0903\n",
      "Epoch [4/5], Loss: 0.1090\n",
      "Epoch [5/5], Loss: 0.0341\n",
      "Client 5 Private Data Validation Accuracy: 72.92%\n",
      "\n",
      "--- Client 6 Local Training with T=0.1202, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4169\n",
      "Epoch [2/5], Loss: 0.0873\n",
      "Epoch [3/5], Loss: 0.0642\n",
      "Epoch [4/5], Loss: 0.0670\n",
      "Epoch [5/5], Loss: 0.0163\n",
      "Client 6 Private Data Validation Accuracy: 74.48%\n",
      "\n",
      "--- Client 7 Local Training with T=0.2973, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4327\n",
      "Epoch [2/5], Loss: 0.1948\n",
      "Epoch [3/5], Loss: 0.1015\n",
      "Epoch [4/5], Loss: 0.0893\n",
      "Epoch [5/5], Loss: 0.0681\n",
      "Client 7 Private Data Validation Accuracy: 63.28%\n",
      "\n",
      "--- Client 8 Local Training with T=0.1253, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.9931\n",
      "Epoch [2/5], Loss: 0.1971\n",
      "Epoch [3/5], Loss: 0.1576\n",
      "Epoch [4/5], Loss: 0.1131\n",
      "Epoch [5/5], Loss: 0.0610\n",
      "Client 8 Private Data Validation Accuracy: 67.19%\n",
      "\n",
      "--- Client 9 Local Training with T=0.0669, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.3038\n",
      "Epoch [2/5], Loss: 0.3158\n",
      "Epoch [3/5], Loss: 0.1901\n",
      "Epoch [4/5], Loss: 0.1836\n",
      "Epoch [5/5], Loss: 0.1018\n",
      "Client 9 Private Data Validation Accuracy: 72.66%\n",
      "\n",
      "--- Client 10 Local Training with T=0.0729, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 3.3269\n",
      "Epoch [2/5], Loss: 0.8916\n",
      "Epoch [3/5], Loss: 0.4065\n",
      "Epoch [4/5], Loss: 0.3226\n",
      "Epoch [5/5], Loss: 0.2936\n",
      "Client 10 Private Data Validation Accuracy: 71.09%\n",
      "Client 10 - decrease-temp to 0.069272\n",
      "\n",
      "--- Client 11 Local Training with T=0.1007, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.5787\n",
      "Epoch [2/5], Loss: 0.4680\n",
      "Epoch [3/5], Loss: 0.1994\n",
      "Epoch [4/5], Loss: 0.0654\n",
      "Epoch [5/5], Loss: 0.0279\n",
      "Client 11 Private Data Validation Accuracy: 82.81%\n",
      "\n",
      "--- Client 12 Local Training with T=0.1634, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5755\n",
      "Epoch [2/5], Loss: 0.0886\n",
      "Epoch [3/5], Loss: 0.0783\n",
      "Epoch [4/5], Loss: 0.0476\n",
      "Epoch [5/5], Loss: 0.0229\n",
      "Client 12 Private Data Validation Accuracy: 79.69%\n",
      "\n",
      "--- Client 13 Local Training with T=0.2209, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6375\n",
      "Epoch [2/5], Loss: 0.2370\n",
      "Epoch [3/5], Loss: 0.1509\n",
      "Epoch [4/5], Loss: 0.0718\n",
      "Epoch [5/5], Loss: 0.0686\n",
      "Client 13 Private Data Validation Accuracy: 64.84%\n",
      "\n",
      "--- Client 14 Local Training with T=0.0921, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 2.4345\n",
      "Epoch [2/5], Loss: 1.0462\n",
      "Epoch [3/5], Loss: 0.6851\n",
      "Epoch [4/5], Loss: 0.2741\n",
      "Epoch [5/5], Loss: 0.1441\n",
      "Client 14 Private Data Validation Accuracy: 62.50%\n",
      "Client 14 - decrease-temp to 0.087477\n",
      "\n",
      "--- Client 15 Local Training with T=0.1063, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.9635\n",
      "Epoch [2/5], Loss: 0.2795\n",
      "Epoch [3/5], Loss: 0.0943\n",
      "Epoch [4/5], Loss: 0.0727\n",
      "Epoch [5/5], Loss: 0.0314\n",
      "Client 15 Private Data Validation Accuracy: 82.29%\n",
      "\n",
      "--- Client 16 Local Training with T=0.2078, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6698\n",
      "Epoch [2/5], Loss: 0.1420\n",
      "Epoch [3/5], Loss: 0.0689\n",
      "Epoch [4/5], Loss: 0.0636\n",
      "Epoch [5/5], Loss: 0.0411\n",
      "Client 16 Private Data Validation Accuracy: 64.58%\n",
      "\n",
      "--- Client 17 Local Training with T=0.0794, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 2.0389\n",
      "Epoch [2/5], Loss: 1.9240\n",
      "Epoch [3/5], Loss: 1.2667\n",
      "Epoch [4/5], Loss: 1.3243\n",
      "Epoch [5/5], Loss: 0.5525\n",
      "Client 17 Private Data Validation Accuracy: 55.47%\n",
      "\n",
      "--- Client 18 Local Training with T=0.1224, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5715\n",
      "Epoch [2/5], Loss: 0.3235\n",
      "Epoch [3/5], Loss: 0.1060\n",
      "Epoch [4/5], Loss: 0.0704\n",
      "Epoch [5/5], Loss: 0.0491\n",
      "Client 18 Private Data Validation Accuracy: 74.22%\n",
      "\n",
      "--- Client 19 Local Training with T=0.0915, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.0575\n",
      "Epoch [2/5], Loss: 0.6644\n",
      "Epoch [3/5], Loss: 1.0220\n",
      "Epoch [4/5], Loss: 0.2668\n",
      "Epoch [5/5], Loss: 0.2078\n",
      "Client 19 Private Data Validation Accuracy: 67.19%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 59.81%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 3.0\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.8110, Initial temp: 0.0878\n",
      "Client 1 - Het Score: 0.9049, Initial temp: 0.0662\n",
      "Client 2 - Het Score: 0.8625, Initial temp: 0.0752\n",
      "Client 3 - Het Score: 0.8298, Initial temp: 0.0830\n",
      "Client 4 - Het Score: 0.7894, Initial temp: 0.0936\n",
      "Client 5 - Het Score: 0.6072, Initial temp: 0.1618\n",
      "Client 6 - Het Score: 0.7654, Initial temp: 0.1006\n",
      "Client 7 - Het Score: 0.4441, Initial temp: 0.2638\n",
      "Client 8 - Het Score: 0.8104, Initial temp: 0.0879\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.0500\n",
      "Client 10 - Het Score: 0.9858, Initial temp: 0.0520\n",
      "Client 11 - Het Score: 0.8155, Initial temp: 0.0866\n",
      "Client 12 - Het Score: 0.6630, Initial temp: 0.1368\n",
      "Client 13 - Het Score: 0.5425, Initial temp: 0.1964\n",
      "Client 14 - Het Score: 0.8720, Initial temp: 0.0731\n",
      "Client 15 - Het Score: 0.7941, Initial temp: 0.0923\n",
      "Client 16 - Het Score: 0.6080, Initial temp: 0.1614\n",
      "Client 17 - Het Score: 0.9110, Initial temp: 0.0650\n",
      "Client 18 - Het Score: 0.8198, Initial temp: 0.0855\n",
      "Client 19 - Het Score: 0.8745, Initial temp: 0.0725\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.0753, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5428\n",
      "Epoch [2/5], Loss: 0.3449\n",
      "Epoch [3/5], Loss: 0.1985\n",
      "Epoch [4/5], Loss: 0.1380\n",
      "Epoch [5/5], Loss: 0.1769\n",
      "Client 0 Private Data Validation Accuracy: 64.58%\n",
      "\n",
      "--- Client 1 Local Training with T=0.0598, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6016\n",
      "Epoch [2/5], Loss: 0.3025\n",
      "Epoch [3/5], Loss: 0.2037\n",
      "Epoch [4/5], Loss: 0.1952\n",
      "Epoch [5/5], Loss: 0.0837\n",
      "Client 1 Private Data Validation Accuracy: 83.85%\n",
      "\n",
      "--- Client 2 Local Training with T=0.0613, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.8151\n",
      "Epoch [2/5], Loss: 0.4131\n",
      "Epoch [3/5], Loss: 0.2097\n",
      "Epoch [4/5], Loss: 0.1840\n",
      "Epoch [5/5], Loss: 0.1702\n",
      "Client 2 Private Data Validation Accuracy: 66.15%\n",
      "\n",
      "--- Client 3 Local Training with T=0.0749, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.7832\n",
      "Epoch [2/5], Loss: 0.1949\n",
      "Epoch [3/5], Loss: 0.1943\n",
      "Epoch [4/5], Loss: 0.1242\n",
      "Epoch [5/5], Loss: 0.0804\n",
      "Client 3 Private Data Validation Accuracy: 70.31%\n",
      "\n",
      "--- Client 4 Local Training with T=0.0845, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5881\n",
      "Epoch [2/5], Loss: 0.3126\n",
      "Epoch [3/5], Loss: 0.2042\n",
      "Epoch [4/5], Loss: 0.1353\n",
      "Epoch [5/5], Loss: 0.1150\n",
      "Client 4 Private Data Validation Accuracy: 70.31%\n",
      "\n",
      "--- Client 5 Local Training with T=0.1318, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5547\n",
      "Epoch [2/5], Loss: 0.2034\n",
      "Epoch [3/5], Loss: 0.2062\n",
      "Epoch [4/5], Loss: 0.1495\n",
      "Epoch [5/5], Loss: 0.1384\n",
      "Client 5 Private Data Validation Accuracy: 70.83%\n",
      "\n",
      "--- Client 6 Local Training with T=0.0779, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4083\n",
      "Epoch [2/5], Loss: 0.1338\n",
      "Epoch [3/5], Loss: 0.1242\n",
      "Epoch [4/5], Loss: 0.0663\n",
      "Epoch [5/5], Loss: 0.0786\n",
      "Client 6 Private Data Validation Accuracy: 68.23%\n",
      "Client 6 - decrease-temp to 0.073971\n",
      "\n",
      "--- Client 7 Local Training with T=0.2262, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4704\n",
      "Epoch [2/5], Loss: 0.2065\n",
      "Epoch [3/5], Loss: 0.1163\n",
      "Epoch [4/5], Loss: 0.1064\n",
      "Epoch [5/5], Loss: 0.0640\n",
      "Client 7 Private Data Validation Accuracy: 58.20%\n",
      "\n",
      "--- Client 8 Local Training with T=0.0835, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.8642\n",
      "Epoch [2/5], Loss: 0.2655\n",
      "Epoch [3/5], Loss: 0.2644\n",
      "Epoch [4/5], Loss: 0.1132\n",
      "Epoch [5/5], Loss: 0.1471\n",
      "Client 8 Private Data Validation Accuracy: 66.41%\n",
      "\n",
      "--- Client 9 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6104\n",
      "Epoch [2/5], Loss: 0.2479\n",
      "Epoch [3/5], Loss: 0.1820\n",
      "Epoch [4/5], Loss: 0.1014\n",
      "Epoch [5/5], Loss: 0.1116\n",
      "Client 9 Private Data Validation Accuracy: 75.78%\n",
      "\n",
      "--- Client 10 Local Training with T=0.0520, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.3132\n",
      "Epoch [2/5], Loss: 0.2766\n",
      "Epoch [3/5], Loss: 0.2218\n",
      "Epoch [4/5], Loss: 0.1189\n",
      "Epoch [5/5], Loss: 0.0970\n",
      "Client 10 Private Data Validation Accuracy: 67.19%\n",
      "\n",
      "--- Client 11 Local Training with T=0.0636, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.4989\n",
      "Epoch [2/5], Loss: 0.3165\n",
      "Epoch [3/5], Loss: 0.1832\n",
      "Epoch [4/5], Loss: 0.1258\n",
      "Epoch [5/5], Loss: 0.1265\n",
      "Client 11 Private Data Validation Accuracy: 82.81%\n",
      "\n",
      "--- Client 12 Local Training with T=0.1173, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5635\n",
      "Epoch [2/5], Loss: 0.1058\n",
      "Epoch [3/5], Loss: 0.1943\n",
      "Epoch [4/5], Loss: 0.0923\n",
      "Epoch [5/5], Loss: 0.0417\n",
      "Client 12 Private Data Validation Accuracy: 81.25%\n",
      "\n",
      "--- Client 13 Local Training with T=0.1684, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6926\n",
      "Epoch [2/5], Loss: 0.3526\n",
      "Epoch [3/5], Loss: 0.1843\n",
      "Epoch [4/5], Loss: 0.1119\n",
      "Epoch [5/5], Loss: 0.0819\n",
      "Client 13 Private Data Validation Accuracy: 57.81%\n",
      "Client 13 - decrease-temp to 0.159969\n",
      "\n",
      "--- Client 14 Local Training with T=0.0660, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.4550\n",
      "Epoch [2/5], Loss: 0.4627\n",
      "Epoch [3/5], Loss: 0.3255\n",
      "Epoch [4/5], Loss: 0.1570\n",
      "Epoch [5/5], Loss: 0.1241\n",
      "Client 14 Private Data Validation Accuracy: 62.50%\n",
      "\n",
      "--- Client 15 Local Training with T=0.0752, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6518\n",
      "Epoch [2/5], Loss: 0.3051\n",
      "Epoch [3/5], Loss: 0.1228\n",
      "Epoch [4/5], Loss: 0.0925\n",
      "Epoch [5/5], Loss: 0.0949\n",
      "Client 15 Private Data Validation Accuracy: 80.73%\n",
      "\n",
      "--- Client 16 Local Training with T=0.1384, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6155\n",
      "Epoch [2/5], Loss: 0.1939\n",
      "Epoch [3/5], Loss: 0.1328\n",
      "Epoch [4/5], Loss: 0.1884\n",
      "Epoch [5/5], Loss: 0.0988\n",
      "Client 16 Private Data Validation Accuracy: 70.83%\n",
      "\n",
      "--- Client 17 Local Training with T=0.0530, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.0731\n",
      "Epoch [2/5], Loss: 0.5438\n",
      "Epoch [3/5], Loss: 0.2968\n",
      "Epoch [4/5], Loss: 0.2043\n",
      "Epoch [5/5], Loss: 0.1409\n",
      "Client 17 Private Data Validation Accuracy: 57.81%\n",
      "\n",
      "--- Client 18 Local Training with T=0.0733, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6712\n",
      "Epoch [2/5], Loss: 0.4104\n",
      "Epoch [3/5], Loss: 0.1826\n",
      "Epoch [4/5], Loss: 0.1013\n",
      "Epoch [5/5], Loss: 0.0881\n",
      "Client 18 Private Data Validation Accuracy: 72.66%\n",
      "\n",
      "--- Client 19 Local Training with T=0.0655, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.0307\n",
      "Epoch [2/5], Loss: 0.3430\n",
      "Epoch [3/5], Loss: 0.2952\n",
      "Epoch [4/5], Loss: 0.1372\n",
      "Epoch [5/5], Loss: 0.1692\n",
      "Client 19 Private Data Validation Accuracy: 66.67%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 56.20%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 3.5\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.8110, Initial temp: 0.0585\n",
      "Client 1 - Het Score: 0.9049, Initial temp: 0.0500\n",
      "Client 2 - Het Score: 0.8625, Initial temp: 0.0500\n",
      "Client 3 - Het Score: 0.8298, Initial temp: 0.0548\n",
      "Client 4 - Het Score: 0.7894, Initial temp: 0.0631\n",
      "Client 5 - Het Score: 0.6072, Initial temp: 0.1194\n",
      "Client 6 - Het Score: 0.7654, Initial temp: 0.0686\n",
      "Client 7 - Het Score: 0.4441, Initial temp: 0.2113\n",
      "Client 8 - Het Score: 0.8104, Initial temp: 0.0586\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.0500\n",
      "Client 10 - Het Score: 0.9858, Initial temp: 0.0500\n",
      "Client 11 - Het Score: 0.8155, Initial temp: 0.0576\n",
      "Client 12 - Het Score: 0.6630, Initial temp: 0.0982\n",
      "Client 13 - Het Score: 0.5425, Initial temp: 0.1497\n",
      "Client 14 - Het Score: 0.8720, Initial temp: 0.0500\n",
      "Client 15 - Het Score: 0.7941, Initial temp: 0.0621\n",
      "Client 16 - Het Score: 0.6080, Initial temp: 0.1191\n",
      "Client 17 - Het Score: 0.9110, Initial temp: 0.0500\n",
      "Client 18 - Het Score: 0.8198, Initial temp: 0.0567\n",
      "Client 19 - Het Score: 0.8745, Initial temp: 0.0500\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.0528, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 2.4793\n",
      "Epoch [2/5], Loss: 0.7302\n",
      "Epoch [3/5], Loss: 0.4767\n",
      "Epoch [4/5], Loss: 0.3153\n",
      "Epoch [5/5], Loss: 0.2402\n",
      "Client 0 Private Data Validation Accuracy: 67.19%\n",
      "\n",
      "--- Client 1 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4987\n",
      "Epoch [2/5], Loss: 0.1230\n",
      "Epoch [3/5], Loss: 0.0645\n",
      "Epoch [4/5], Loss: 0.0514\n",
      "Epoch [5/5], Loss: 0.0291\n",
      "Client 1 Private Data Validation Accuracy: 84.90%\n",
      "\n",
      "--- Client 2 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5530\n",
      "Epoch [2/5], Loss: 0.2828\n",
      "Epoch [3/5], Loss: 0.2184\n",
      "Epoch [4/5], Loss: 0.2091\n",
      "Epoch [5/5], Loss: 0.1269\n",
      "Client 2 Private Data Validation Accuracy: 64.58%\n",
      "\n",
      "--- Client 3 Local Training with T=0.0548, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6191\n",
      "Epoch [2/5], Loss: 0.3466\n",
      "Epoch [3/5], Loss: 0.7381\n",
      "Epoch [4/5], Loss: 0.1688\n",
      "Epoch [5/5], Loss: 0.1984\n",
      "Client 3 Private Data Validation Accuracy: 68.23%\n",
      "\n",
      "--- Client 4 Local Training with T=0.0541, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6369\n",
      "Epoch [2/5], Loss: 0.4701\n",
      "Epoch [3/5], Loss: 0.3726\n",
      "Epoch [4/5], Loss: 0.1925\n",
      "Epoch [5/5], Loss: 0.1815\n",
      "Client 4 Private Data Validation Accuracy: 67.97%\n",
      "\n",
      "--- Client 5 Local Training with T=0.1024, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5061\n",
      "Epoch [2/5], Loss: 0.3094\n",
      "Epoch [3/5], Loss: 0.2356\n",
      "Epoch [4/5], Loss: 0.2339\n",
      "Epoch [5/5], Loss: 0.1434\n",
      "Client 5 Private Data Validation Accuracy: 68.23%\n",
      "\n",
      "--- Client 6 Local Training with T=0.0531, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6189\n",
      "Epoch [2/5], Loss: 0.2053\n",
      "Epoch [3/5], Loss: 0.1308\n",
      "Epoch [4/5], Loss: 0.0860\n",
      "Epoch [5/5], Loss: 0.0396\n",
      "Client 6 Private Data Validation Accuracy: 74.48%\n",
      "\n",
      "--- Client 7 Local Training with T=0.1812, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3959\n",
      "Epoch [2/5], Loss: 0.1754\n",
      "Epoch [3/5], Loss: 0.1513\n",
      "Epoch [4/5], Loss: 0.1346\n",
      "Epoch [5/5], Loss: 0.0799\n",
      "Client 7 Private Data Validation Accuracy: 61.33%\n",
      "\n",
      "--- Client 8 Local Training with T=0.0557, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.5887\n",
      "Epoch [2/5], Loss: 0.3830\n",
      "Epoch [3/5], Loss: 0.2877\n",
      "Epoch [4/5], Loss: 0.2314\n",
      "Epoch [5/5], Loss: 0.1415\n",
      "Client 8 Private Data Validation Accuracy: 75.78%\n",
      "\n",
      "--- Client 9 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6636\n",
      "Epoch [2/5], Loss: 0.2645\n",
      "Epoch [3/5], Loss: 0.1587\n",
      "Epoch [4/5], Loss: 0.1342\n",
      "Epoch [5/5], Loss: 0.2142\n",
      "Client 9 Private Data Validation Accuracy: 73.44%\n",
      "\n",
      "--- Client 10 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.9798\n",
      "Epoch [2/5], Loss: 0.3187\n",
      "Epoch [3/5], Loss: 0.1999\n",
      "Epoch [4/5], Loss: 0.1500\n",
      "Epoch [5/5], Loss: 0.0894\n",
      "Client 10 Private Data Validation Accuracy: 72.27%\n",
      "\n",
      "--- Client 11 Local Training with T=0.0547, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.2769\n",
      "Epoch [2/5], Loss: 0.2421\n",
      "Epoch [3/5], Loss: 0.1674\n",
      "Epoch [4/5], Loss: 0.1070\n",
      "Epoch [5/5], Loss: 0.0659\n",
      "Client 11 Private Data Validation Accuracy: 79.69%\n",
      "\n",
      "--- Client 12 Local Training with T=0.0886, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5866\n",
      "Epoch [2/5], Loss: 0.1728\n",
      "Epoch [3/5], Loss: 0.1066\n",
      "Epoch [4/5], Loss: 0.1362\n",
      "Epoch [5/5], Loss: 0.1178\n",
      "Client 12 Private Data Validation Accuracy: 79.17%\n",
      "\n",
      "--- Client 13 Local Training with T=0.1284, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6750\n",
      "Epoch [2/5], Loss: 0.4772\n",
      "Epoch [3/5], Loss: 0.2428\n",
      "Epoch [4/5], Loss: 0.1211\n",
      "Epoch [5/5], Loss: 0.1220\n",
      "Client 13 Private Data Validation Accuracy: 64.84%\n",
      "\n",
      "--- Client 14 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 2.2579\n",
      "Epoch [2/5], Loss: 0.6106\n",
      "Epoch [3/5], Loss: 0.3154\n",
      "Epoch [4/5], Loss: 0.1695\n",
      "Epoch [5/5], Loss: 0.2441\n",
      "Client 14 Private Data Validation Accuracy: 53.12%\n",
      "\n",
      "--- Client 15 Local Training with T=0.0560, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.6704\n",
      "Epoch [2/5], Loss: 0.2447\n",
      "Epoch [3/5], Loss: 0.1112\n",
      "Epoch [4/5], Loss: 0.0859\n",
      "Epoch [5/5], Loss: 0.0531\n",
      "Client 15 Private Data Validation Accuracy: 79.17%\n",
      "\n",
      "--- Client 16 Local Training with T=0.1131, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.9354\n",
      "Epoch [2/5], Loss: 0.3678\n",
      "Epoch [3/5], Loss: 0.2088\n",
      "Epoch [4/5], Loss: 0.1565\n",
      "Epoch [5/5], Loss: 0.1094\n",
      "Client 16 Private Data Validation Accuracy: 68.75%\n",
      "\n",
      "--- Client 17 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.8995\n",
      "Epoch [2/5], Loss: 0.4113\n",
      "Epoch [3/5], Loss: 0.5373\n",
      "Epoch [4/5], Loss: 0.4986\n",
      "Epoch [5/5], Loss: 0.2422\n",
      "Client 17 Private Data Validation Accuracy: 60.16%\n",
      "\n",
      "--- Client 18 Local Training with T=0.0539, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.7728\n",
      "Epoch [2/5], Loss: 0.2930\n",
      "Epoch [3/5], Loss: 0.2405\n",
      "Epoch [4/5], Loss: 0.1618\n",
      "Epoch [5/5], Loss: 0.0824\n",
      "Client 18 Private Data Validation Accuracy: 70.31%\n",
      "\n",
      "--- Client 19 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.5485\n",
      "Epoch [2/5], Loss: 0.7659\n",
      "Epoch [3/5], Loss: 0.3244\n",
      "Epoch [4/5], Loss: 0.1451\n",
      "Epoch [5/5], Loss: 0.1844\n",
      "Client 19 Private Data Validation Accuracy: 67.71%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 56.32%\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'server_acc' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_36/2023819718.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrun_experiment\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mplot_model_comparison\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_36/1585090999.py\u001b[0m in \u001b[0;36mplot_model_comparison\u001b[0;34m(results)\u001b[0m\n\u001b[1;32m     39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     40\u001b[0m     \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfigsize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m     \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mwidth\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mserver_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwidth\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Server Acc (%)'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     42\u001b[0m     \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mwidth\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclient_acc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwidth\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Average Client Acc (%)'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     43\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'server_acc' is not defined"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 1000x600 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Run the experiment; its training/evaluation log is streamed to this cell's output.\n",
    "results = run_experiment()\n",
    "\n",
    "# Report the raw results before plotting, so they are still shown\n",
    "# even if the plotting helper raises an exception.\n",
    "print(results)\n",
    "\n",
    "plot_model_comparison(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nYAg_h1n7NAt"
   },
   "source": [
    "---"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kaggle": {
   "accelerator": "nvidiaTeslaT4",
   "dataSources": [],
   "dockerImageVersionId": 31090,
   "isGpuEnabled": true,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
