{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vgUXRKpJh1yE"
   },
   "source": [
    "# FedChill - Ci\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "execution": {
     "iopub.execute_input": "2025-09-22T17:17:38.932979Z",
     "iopub.status.busy": "2025-09-22T17:17:38.932345Z",
     "iopub.status.idle": "2025-09-22T17:17:46.432127Z",
     "shell.execute_reply": "2025-09-22T17:17:46.431357Z",
     "shell.execute_reply.started": "2025-09-22T17:17:38.932956Z"
    },
    "id": "_dq_ZRdAh1yG",
    "outputId": "c38ff43a-2f89-4321-d643-5b651b9429aa",
    "trusted": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import numpy as np\n",
    "from torch.utils.data import Dataset, DataLoader, Subset, random_split\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import Counter\n",
    "import copy\n",
    "import random\n",
    "from typing import Dict, List, Tuple\n",
    "import math\n",
    "\n",
    "import pandas as pd\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "random.seed(42)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2NKqWJJTS7GB"
   },
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-22T17:17:48.470635Z",
     "iopub.status.busy": "2025-09-22T17:17:48.469778Z",
     "iopub.status.idle": "2025-09-22T17:17:48.473904Z",
     "shell.execute_reply": "2025-09-22T17:17:48.473246Z",
     "shell.execute_reply.started": "2025-09-22T17:17:48.470584Z"
    },
    "id": "hHTqcMmPeQ1c",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "NUM_CLASSES = 10 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-22T17:29:35.213826Z",
     "iopub.status.busy": "2025-09-22T17:29:35.213295Z",
     "iopub.status.idle": "2025-09-22T17:29:35.226981Z",
     "shell.execute_reply": "2025-09-22T17:29:35.226097Z",
     "shell.execute_reply.started": "2025-09-22T17:29:35.213803Z"
    },
    "id": "gVCZKSAhh1yH",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "class SimpleCNN(nn.Module):\n",
    "    def __init__(self, num_classes=NUM_CLASSES, seed=42):\n",
    "        super(SimpleCNN, self).__init__()\n",
    "\n",
    "        torch.manual_seed(seed)\n",
    "\n",
    "        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)\n",
    "        self.bn1 = nn.BatchNorm2d(64)\n",
    "        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)\n",
    "        self.bn2 = nn.BatchNorm2d(64)\n",
    "\n",
    "        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)\n",
    "        self.bn3 = nn.BatchNorm2d(128)\n",
    "        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)\n",
    "        self.bn4 = nn.BatchNorm2d(128)\n",
    "\n",
    "        self.conv5 = nn.Conv2d(128, 256, 3, padding=1)\n",
    "        self.bn5 = nn.BatchNorm2d(256)\n",
    "        self.conv6 = nn.Conv2d(256, 256, 3, padding=1)\n",
    "        self.bn6 = nn.BatchNorm2d(256)\n",
    "\n",
    "        self.conv7 = nn.Conv2d(256, 512, 3, padding=1)\n",
    "        self.bn7 = nn.BatchNorm2d(512)\n",
    "\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.adaptive_pool = nn.AdaptiveAvgPool2d((2, 2))  \n",
    "\n",
    "        self.fc1 = nn.Linear(512 * 2 * 2, 1024)\n",
    "        self.bn_fc1 = nn.BatchNorm1d(1024)\n",
    "        self.dropout1 = nn.Dropout(0.5)\n",
    "\n",
    "        self.fc2 = nn.Linear(1024, 512)\n",
    "        self.bn_fc2 = nn.BatchNorm1d(512)\n",
    "        self.dropout2 = nn.Dropout(0.3)\n",
    "\n",
    "        self.fc3 = nn.Linear(512, num_classes)\n",
    "\n",
    "        self._initialize_weights()\n",
    "\n",
    "        torch.manual_seed(torch.initial_seed())\n",
    "\n",
    "    def _initialize_weights(self):\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
    "                if m.bias is not None:\n",
    "                    nn.init.constant_(m.bias, 0)\n",
    "            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):\n",
    "                nn.init.constant_(m.weight, 1)\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "            elif isinstance(m, nn.Linear):\n",
    "                nn.init.normal_(m.weight, 0, 0.01)\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        x = F.relu(self.bn1(self.conv1(x)))\n",
    "        x = F.relu(self.bn2(self.conv2(x)))\n",
    "        x = self.pool(x)  \n",
    "\n",
    "        x = F.relu(self.bn3(self.conv3(x)))\n",
    "        x = F.relu(self.bn4(self.conv4(x)))\n",
    "        x = self.pool(x) \n",
    "\n",
    "        x = F.relu(self.bn5(self.conv5(x)))\n",
    "        x = F.relu(self.bn6(self.conv6(x)))\n",
    "        x = self.pool(x)  \n",
    "\n",
    "        x = F.relu(self.bn7(self.conv7(x)))\n",
    "        x = self.adaptive_pool(x) \n",
    "\n",
    "        x = x.view(x.size(0), -1)  \n",
    "\n",
    "        x = F.relu(self.bn_fc1(self.fc1(x)))\n",
    "        x = self.dropout1(x)\n",
    "\n",
    "        x = F.relu(self.bn_fc2(self.fc2(x)))\n",
    "        x = self.dropout2(x)\n",
    "\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "    def get_features(self, x):\n",
    "\n",
    "        x = F.relu(self.bn1(self.conv1(x)))\n",
    "        x = F.relu(self.bn2(self.conv2(x)))\n",
    "        x = self.pool(x)\n",
    "\n",
    "\n",
    "        x = F.relu(self.bn3(self.conv3(x)))\n",
    "        x = F.relu(self.bn4(self.conv4(x)))\n",
    "        x = self.pool(x)\n",
    "\n",
    "\n",
    "        x = F.relu(self.bn5(self.conv5(x)))\n",
    "        x = F.relu(self.bn6(self.conv6(x)))\n",
    "        x = self.pool(x)\n",
    "\n",
    "\n",
    "        x = F.relu(self.bn7(self.conv7(x)))\n",
    "        x = self.adaptive_pool(x)\n",
    "\n",
    "\n",
    "        x = x.view(x.size(0), -1)\n",
    "        x = F.relu(self.bn_fc1(self.fc1(x)))\n",
    "        x = self.dropout1(x)\n",
    "        x = F.relu(self.bn_fc2(self.fc2(x)))\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-22T17:17:54.916465Z",
     "iopub.status.busy": "2025-09-22T17:17:54.915701Z",
     "iopub.status.idle": "2025-09-22T17:17:54.920914Z",
     "shell.execute_reply": "2025-09-22T17:17:54.920156Z",
     "shell.execute_reply.started": "2025-09-22T17:17:54.916434Z"
    },
    "id": "fKFX-gLRddsE",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 64\n",
    "LEARNING_RATE = 0.01\n",
    "LOCAL_EPOCHS = 5\n",
    "NUM_OF_CLIENTS = 10\n",
    "COMM_ROUND = 30\n",
    "ALPHA = 0.5  \n",
    "FRAC = 0.1   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "oJrL1U1Hh1yJ"
   },
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-22T17:18:00.390631Z",
     "iopub.status.busy": "2025-09-22T17:18:00.389929Z",
     "iopub.status.idle": "2025-09-22T17:18:00.403425Z",
     "shell.execute_reply": "2025-09-22T17:18:00.402795Z",
     "shell.execute_reply.started": "2025-09-22T17:18:00.390591Z"
    },
    "id": "CImsSomMh1yH",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def load_and_partition_data(num_clients=NUM_OF_CLIENTS, alpha=ALPHA, batch_size=BATCH_SIZE, frac=FRAC, rand_seed=42):\n",
    "\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
    "    ])\n",
    "\n",
    "\n",
    "    torch.manual_seed(rand_seed)\n",
    "    np.random.seed(rand_seed)\n",
    "\n",
    "    full_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "    y_train = np.array(full_dataset.targets)\n",
    "    y_test = np.array(test_dataset.targets)\n",
    "\n",
    "    num_classes = NUM_CLASSES\n",
    "    N = len(full_dataset)  \n",
    "    N_test = len(test_dataset) \n",
    "\n",
    "    net_dataidx_map = {}\n",
    "    net_dataidx_map_test = {}\n",
    "\n",
    "    min_size = 0\n",
    "    while min_size < 10: \n",
    "        idx_batch = [[] for _ in range(num_clients)]\n",
    "        idx_batch_test = [[] for _ in range(num_clients)]\n",
    "        for k in range(num_classes):\n",
    "            idx_k = np.where(y_train == k)[0]\n",
    "            idx_k_test = np.where(y_test == k)[0]\n",
    "            np.random.shuffle(idx_k)  \n",
    "            np.random.shuffle(idx_k_test)  \n",
    "            proportions = np.random.dirichlet(np.repeat(alpha, num_clients))\n",
    "            proportions_train = np.array([p * (len(idx_j) < N / num_clients) for p, idx_j in zip(proportions, idx_batch)])\n",
    "            proportions_test = np.array([p * (len(idx_j) < N_test / num_clients) for p, idx_j in zip(proportions, idx_batch_test)])\n",
    "            proportions_train = proportions_train / proportions_train.sum()  \n",
    "            proportions_test = proportions_test / proportions_test.sum()  \n",
    "            proportions_train = (np.cumsum(proportions_train) * len(idx_k)).astype(int)[:-1]\n",
    "            proportions_test = (np.cumsum(proportions_test) * len(idx_k_test)).astype(int)[:-1]\n",
    "            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions_train))]\n",
    "            idx_batch_test = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch_test, np.split(idx_k_test, proportions_test))]\n",
    "        min_size = min([len(idx_j) for idx_j in idx_batch])  \n",
    "\n",
    "    for j in range(num_clients):\n",
    "        np.random.shuffle(idx_batch[j])\n",
    "        np.random.shuffle(idx_batch_test[j])\n",
    "        net_dataidx_map[j] = idx_batch[j]\n",
    "        net_dataidx_map_test[j] = idx_batch_test[j]\n",
    "\n",
    "    client_train_loaders = []\n",
    "    client_val_loaders = []\n",
    "    client_class_distributions = []\n",
    "\n",
    "    for i in range(num_clients):\n",
    "        np.random.seed(rand_seed + i)\n",
    "\n",
    "        num_data = len(net_dataidx_map[i])\n",
    "        frac_num_data = int(frac * num_data)\n",
    "        frac_indices = np.random.choice(num_data, frac_num_data, replace=False)\n",
    "        train_indices = [net_dataidx_map[i][j] for j in frac_indices]\n",
    "\n",
    "        num_data_test = len(net_dataidx_map_test[i])\n",
    "        frac_num_data_test = int(min(2 * frac, 1.0) * num_data_test)  \n",
    "        frac_indices_test = np.random.choice(num_data_test, frac_num_data_test, replace=False)\n",
    "        val_indices = [net_dataidx_map_test[i][j] for j in frac_indices_test]\n",
    "\n",
    "        client_labels = [y_train[idx] for idx in train_indices]\n",
    "        class_counts = Counter(client_labels)\n",
    "        distribution = {cls: class_counts.get(cls, 0) / len(client_labels) if len(client_labels) > 0 else 0 for cls in range(num_classes)}\n",
    "        client_class_distributions.append(distribution)\n",
    "\n",
    "        client_train_dataset = Subset(full_dataset, train_indices)\n",
    "        client_val_dataset = Subset(test_dataset, val_indices)\n",
    "\n",
    "        g_train = torch.Generator().manual_seed(rand_seed + i)\n",
    "        g_val = torch.Generator().manual_seed(rand_seed + i + num_clients)\n",
    "\n",
    "        train_loader = DataLoader(client_train_dataset, batch_size=batch_size,\n",
    "                                 shuffle=True, generator=g_train, drop_last=True)\n",
    "        val_loader = DataLoader(client_val_dataset, batch_size=batch_size,\n",
    "                                shuffle=True, generator=g_val, drop_last=True)\n",
    "\n",
    "        client_train_loaders.append(train_loader)\n",
    "        client_val_loaders.append(val_loader)\n",
    "\n",
    "    g_test = torch.Generator().manual_seed(rand_seed + 2 * num_clients + 1)\n",
    "    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True,\n",
    "                             generator=g_test, num_workers=2)\n",
    "\n",
    "    print(\"Data partitioning complete.\")\n",
    "    torch.manual_seed(rand_seed)\n",
    "    np.random.seed(rand_seed)\n",
    "\n",
    "    return client_train_loaders, client_val_loaders, test_loader, client_class_distributions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OOr1Qsu_S7GC"
   },
   "source": [
    "## Utility Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-22T17:18:03.708876Z",
     "iopub.status.busy": "2025-09-22T17:18:03.708589Z",
     "iopub.status.idle": "2025-09-22T17:18:03.719052Z",
     "shell.execute_reply": "2025-09-22T17:18:03.718496Z",
     "shell.execute_reply.started": "2025-09-22T17:18:03.708857Z"
    },
    "id": "f2GMhq3Qh1yJ",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def client_train(model, train_loader, optimizer, epochs, print_flag=True):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        for inputs, targets in train_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = F.cross_entropy(outputs, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        if print_flag:\n",
    "            print(f\"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}\")\n",
    "\n",
    "def aggregate_models(global_model, client_models, client_indices, train_loaders):\n",
    "    global_dict = global_model.state_dict()\n",
    "\n",
    "    n_samples = [len(train_loaders[idx].dataset) for idx in client_indices]\n",
    "    total_samples = sum(n_samples)\n",
    "\n",
    "    aggregated_dict = {}\n",
    "\n",
    "    for key in global_dict.keys():\n",
    "        aggregated_dict[key] = torch.zeros_like(global_dict[key], dtype=torch.float32)\n",
    "\n",
    "    for i, idx in enumerate(client_indices):\n",
    "        client_dict = client_models[idx].state_dict()\n",
    "        weight = n_samples[i] / total_samples\n",
    "\n",
    "        for key in global_dict.keys():\n",
    "            aggregated_dict[key] += client_dict[key].float() * weight\n",
    "\n",
    "    for key in global_dict.keys():\n",
    "        global_dict[key] = aggregated_dict[key].to(dtype=global_dict[key].dtype)\n",
    "\n",
    "    global_model.load_state_dict(global_dict)\n",
    "\n",
    "\n",
    "def evaluate_model(model, data_loader):\n",
    "\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in data_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            outputs = model(inputs)\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            total += targets.size(0)\n",
    "            correct += (predicted == targets).sum().item()\n",
    "    return 100 * correct / total\n",
    "\n",
    "def client_train_with_temp(model, train_loader, optimizer, epochs, temperature=1.0, print_flag=True):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        for inputs, targets in train_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "\n",
    "            if temperature != 1.0:\n",
    "                log_probs = F.log_softmax(outputs / temperature, dim=1)\n",
    "                loss = F.nll_loss(log_probs, targets)\n",
    "            else:\n",
    "                loss = F.cross_entropy(outputs, targets)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        if print_flag:\n",
    "            print(f\"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "changed temp adjust factor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-22T17:24:26.784955Z",
     "iopub.status.busy": "2025-09-22T17:24:26.784253Z",
     "iopub.status.idle": "2025-09-22T17:24:26.800721Z",
     "shell.execute_reply": "2025-09-22T17:24:26.799990Z",
     "shell.execute_reply.started": "2025-09-22T17:24:26.784929Z"
    },
    "id": "g0f5rDJn68sE",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "class AdaptiveClientHyperparams:\n",
    "    def __init__(self, initial_lr=0.01, min_lr=0.0001, max_lr=0.05,\n",
    "                 initial_temp=0.5, min_temp=0.05, max_temp=1.0,\n",
    "                 patience=2, lr_decay_factor=0.7, temp_adjust_factor=0.95, scaling_factor=3):\n",
    "        self.lr = initial_lr\n",
    "        self.min_lr = min_lr\n",
    "        self.max_lr = max_lr\n",
    "        self.lr_decay_factor = lr_decay_factor\n",
    "        self.scaling_factor = scaling_factor\n",
    "\n",
    "        self.temp = initial_temp\n",
    "        self.min_temp = min_temp\n",
    "        self.max_temp = max_temp\n",
    "        self.temp_adjust_factor = temp_adjust_factor\n",
    "\n",
    "        self.patience = patience\n",
    "        self.performance_history = []\n",
    "        self.stagnation_count = 0\n",
    "        self.improvement_count = 0\n",
    "        self.het_score = 0.0    \n",
    "\n",
    "    def update_temp_from_heterogeneity(self, het_score):\n",
    "        self.het_score = het_score\n",
    "        self.temp = self.max_temp * math.exp(-self.scaling_factor * het_score) \n",
    "        self.temp = min(max(self.temp, self.min_temp), self.max_temp)\n",
    "\n",
    "    def record_performance(self, accuracy):\n",
    "        self.performance_history.append(accuracy)\n",
    "\n",
    "        if len(self.performance_history) >= 3:\n",
    "            if self.performance_history[-1] <= self.performance_history[-2]:\n",
    "                self.stagnation_count += 1\n",
    "                self.improvement_count = 0\n",
    "            else:\n",
    "                self.improvement_count += 1\n",
    "                self.stagnation_count = 0\n",
    "\n",
    "            if self.stagnation_count >= self.patience:\n",
    "                if self.temp > self.min_temp * 1.1:  \n",
    "                    self.temp *= self.temp_adjust_factor\n",
    "                    self.temp = max(self.temp, self.min_temp)\n",
    "                    self.stagnation_count = 0\n",
    "                    return True, \"decrease-temp\", self.temp\n",
    "\n",
    "        return False, \"unchanged\", None\n",
    "\n",
    "    def get_lr(self):\n",
    "        return self.lr\n",
    "\n",
    "    def get_temp(self):\n",
    "        return self.temp\n",
    "\n",
    "\n",
    "def calculate_heterogeneity_score(client_distribution, global_distribution):\n",
    "    score = 0\n",
    "    for cls in range(10):\n",
    "        client_prob = client_distribution.get(cls, 0)\n",
    "        global_prob = global_distribution.get(cls, 1/10)\n",
    "        if client_prob > 0 and global_prob > 0:\n",
    "            ratio = client_prob / global_prob\n",
    "            score += abs(ratio - 1)\n",
    "\n",
    "    score = min(score / 10, 1)\n",
    "    return score\n",
    "\n",
    "def calculate_global_distribution(client_class_distributions):\n",
    "    global_dist = {}\n",
    "    n_clients = len(client_class_distributions)\n",
    "\n",
    "    for cls in range(10):\n",
    "        global_dist[cls] = sum(dist.get(cls, 0) for dist in client_class_distributions) / n_clients\n",
    "\n",
    "    for cls in range(10):\n",
    "        if global_dist[cls] < 0.01:\n",
    "            global_dist[cls] = 0.01\n",
    "\n",
    "    total = sum(global_dist.values())\n",
    "    global_dist = {k: v/total for k, v in global_dist.items()}\n",
    "\n",
    "    return global_dist\n",
    "\n",
    "def client_train(model, train_loader, optimizer, epochs, print_flag=True):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        for inputs, targets in train_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = F.cross_entropy(outputs, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        if print_flag:\n",
    "            print(f\"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}\")\n",
    "\n",
    "def aggregate_models(global_model, client_models, client_indices, train_loaders):\n",
    "    \n",
    "    global_dict = global_model.state_dict()\n",
    "\n",
    "    n_samples = [len(train_loaders[idx].dataset) for idx in client_indices]\n",
    "    total_samples = sum(n_samples)\n",
    "\n",
    "    aggregated_dict = {}\n",
    "\n",
    "    for key in global_dict.keys():\n",
    "        aggregated_dict[key] = torch.zeros_like(global_dict[key], dtype=torch.float32)\n",
    "\n",
    "    for i, idx in enumerate(client_indices):\n",
    "        client_dict = client_models[idx].state_dict()\n",
    "        weight = n_samples[i] / total_samples\n",
    "\n",
    "        for key in global_dict.keys():\n",
    "            aggregated_dict[key] += client_dict[key].float() * weight\n",
    "\n",
    "    for key in global_dict.keys():\n",
    "        global_dict[key] = aggregated_dict[key].to(dtype=global_dict[key].dtype)\n",
    "\n",
    "    global_model.load_state_dict(global_dict)\n",
    "\n",
    "\n",
    "def evaluate_model(model, data_loader):\n",
    "\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in data_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            outputs = model(inputs)\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            total += targets.size(0)\n",
    "            correct += (predicted == targets).sum().item()\n",
    "    return 100 * correct / total\n",
    "\n",
    "def client_train_with_temp(model, train_loader, optimizer, epochs, temperature=1.0, print_flag=True):\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        running_loss = 0.0\n",
    "        for inputs, targets in train_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "\n",
    "            if temperature != 1.0:\n",
    "                log_probs = F.log_softmax(outputs / temperature, dim=1)\n",
    "                loss = F.nll_loss(log_probs, targets)\n",
    "            else:\n",
    "                loss = F.cross_entropy(outputs, targets)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        if print_flag:\n",
    "            print(f\"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1-V6QoWgA2hB"
   },
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6VOgsXc-A2hB"
   },
   "source": [
    "**HYPERPARAMETERS**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-22T17:24:29.127502Z",
     "iopub.status.busy": "2025-09-22T17:24:29.126786Z",
     "iopub.status.idle": "2025-09-22T17:24:29.131521Z",
     "shell.execute_reply": "2025-09-22T17:24:29.130641Z",
     "shell.execute_reply.started": "2025-09-22T17:24:29.127476Z"
    },
    "id": "RlQUq8Ht61cg",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 64\n",
    "LEARNING_RATE = 0.01\n",
    "LOCAL_EPOCHS = 5\n",
    "NUM_OF_CLIENTS = 10\n",
    "COMM_ROUND = 30\n",
    "ALPHA = 0.5  \n",
    "FRAC = 0.1   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-22T17:29:42.278511Z",
     "iopub.status.busy": "2025-09-22T17:29:42.277972Z",
     "iopub.status.idle": "2025-09-22T17:29:42.285083Z",
     "shell.execute_reply": "2025-09-22T17:29:42.284422Z",
     "shell.execute_reply.started": "2025-09-22T17:29:42.278491Z"
    },
    "id": "48RvxVUtA2hB",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def run_experiment():\n",
    "    print(f\"\\n\\n{'='*80}\")\n",
    "    print(f\"{'='*80}\\n\")\n",
    "\n",
    "    model_results = {}\n",
    "\n",
    "    scaling_fact = np.arange(0.5, 4.5, 0.5).tolist()\n",
    "\n",
    "    for sf in scaling_fact:\n",
    "        print(f\"\\n\\n{'#'*50}\")\n",
    "        print(f\"# RUNNING FedChill with Scaling Factor {sf}\")\n",
    "        print(f\"{'#'*50}\\n\")\n",
    "\n",
    "        server_acc, client_acc, final_model = run_fedchill(scaling_factor=sf)\n",
    "\n",
    "        avg_client_acc = np.mean(client_acc, axis=0).tolist()\n",
    "        server_acc = server_acc[-1]\n",
    "\n",
    "        model_results[sf] = {\n",
    "            'server_acc': server_acc,\n",
    "            'client_acc': client_acc,\n",
    "            'final_model': final_model\n",
    "        }\n",
    "\n",
    "\n",
    "    return model_results\n",
    "\n",
    "def plot_model_comparison(results):\n",
    "    \"\"\"Bar chart: Server vs Client accuracy across scaling factors.\"\"\"\n",
    "    scaling_factors = list(results.keys())\n",
    "\n",
    "    server_acc = [data['server_acc'][-1] for data in results.values()]\n",
    "    client_acc = [np.mean(data['client_acc'][-1]) if isinstance(data['client_acc'][-1], (list, np.ndarray))\n",
    "                  else data['client_acc'][-1] for data in results.values()]\n",
    "\n",
    "    x = np.arange(len(scaling_factors))  \n",
    "    width = 0.35  \n",
    "\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    plt.bar(x - width/2, server_acc, width, label='Server Acc (%)')\n",
    "    plt.bar(x + width/2, client_acc, width, label='Average Client Acc (%)')\n",
    "\n",
    "    plt.xticks(x, [str(sf) for sf in scaling_factors])\n",
    "    plt.xlabel(\"Scaling Factor\")\n",
    "    plt.ylabel(\"Accuracy (%)\")\n",
    "    plt.title(\"Server vs Client Accuracy by Scaling Factor\")\n",
    "    plt.legend()\n",
    "    plt.grid(axis='y', linestyle='--', alpha=0.7)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"scaling_factor_comparison.png\")\n",
    "    plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-22T17:29:44.469247Z",
     "iopub.status.busy": "2025-09-22T17:29:44.468733Z",
     "iopub.status.idle": "2025-09-22T17:29:44.479321Z",
     "shell.execute_reply": "2025-09-22T17:29:44.478597Z",
     "shell.execute_reply.started": "2025-09-22T17:29:44.469223Z"
    },
    "id": "ZGkxcCH6A2hC",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def run_fedchill(scaling_factor=3, model_class=SimpleCNN):\n",
    "    print_flag = False\n",
    "\n",
    "    client_train_loaders, client_val_loaders, test_loader, client_class_distributions = load_and_partition_data(\n",
    "        num_clients=NUM_OF_CLIENTS,\n",
    "        alpha=ALPHA,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        frac=FRAC,\n",
    "        rand_seed=42\n",
    "    )\n",
    "    num_clients = len(client_train_loaders)\n",
    "\n",
    "    global_distribution = calculate_global_distribution(client_class_distributions)\n",
    "\n",
    "    client_params = [AdaptiveClientHyperparams(scaling_factor=scaling_factor) for _ in range(num_clients)]\n",
    "\n",
    "    for i, distribution in enumerate(client_class_distributions):\n",
    "        het_score = calculate_heterogeneity_score(distribution, global_distribution)\n",
    "        client_params[i].update_temp_from_heterogeneity(het_score)\n",
    "        print(f\"Client {i} - Het Score: {het_score:.4f}, Initial temp: {client_params[i].get_temp():.4f}\")\n",
    "\n",
    "    global_model = model_class().to(device)\n",
    "\n",
    "    round_server_acc_list = []\n",
    "    round_client_acc_list = [[] for _ in range(num_clients)]\n",
    "\n",
    "    client_temp_history = [[] for _ in range(num_clients)]\n",
    "    client_lr_history = [[] for _ in range(num_clients)]\n",
    "\n",
    "    client_models = [copy.deepcopy(global_model).to(device) for _ in range(num_clients)]\n",
    "    client_optimizers = [optim.SGD(model.parameters(), lr=client_params[i].get_lr()) for i, model in enumerate(client_models)]\n",
    "\n",
    "    for comm_round in range(COMM_ROUND):\n",
    "        print(f\"\\n{'='*20} COMMUNICATION ROUND {comm_round+1} {'='*20}\")\n",
    "\n",
    "        selected_clients = list(range(num_clients))\n",
    "\n",
    "        if comm_round == COMM_ROUND - 1:\n",
    "            print_flag = True\n",
    "            print(\"\\n--- LOCAL TRAINING OF CLIENTS ---\")\n",
    "\n",
    "        for client_idx in selected_clients:\n",
    "            client_models[client_idx].load_state_dict(global_model.state_dict())\n",
    "\n",
    "            curr_lr = client_params[client_idx].get_lr()\n",
    "            curr_temp = client_params[client_idx].get_temp()\n",
    "\n",
    "            client_temp_history[client_idx].append(curr_temp)\n",
    "            client_lr_history[client_idx].append(curr_lr)\n",
    "\n",
    "            client_optimizers[client_idx] = optim.SGD(client_models[client_idx].parameters(), lr=curr_lr)\n",
    "\n",
    "            if comm_round == COMM_ROUND - 1:\n",
    "                print(f\"\\n--- Client {client_idx} Local Training with T={curr_temp:.4f}, LR={curr_lr:.6f} ---\")\n",
    "\n",
    "            client_train_with_temp(\n",
    "                client_models[client_idx],\n",
    "                client_train_loaders[client_idx],\n",
    "                client_optimizers[client_idx],\n",
    "                LOCAL_EPOCHS,\n",
    "                temperature=curr_temp,\n",
    "                print_flag=print_flag\n",
    "            )\n",
    "\n",
    "            local_acc = evaluate_model(client_models[client_idx], client_val_loaders[client_idx])\n",
    "            if comm_round == COMM_ROUND - 1:\n",
    "                print(f\"Client {client_idx} Private Data Validation Accuracy: {local_acc:.2f}%\")\n",
    "\n",
    "            client_testset_acc = evaluate_model(client_models[client_idx], test_loader)\n",
    "            round_client_acc_list[client_idx].append(local_acc)\n",
    "\n",
    "            adjusted, param_type, new_value = client_params[client_idx].record_performance(local_acc)\n",
    "            if adjusted and (comm_round == COMM_ROUND - 1 or print_flag):\n",
    "                print(f\"Client {client_idx} - {param_type} to {new_value:.6f}\")\n",
    "\n",
    "        aggregate_models(global_model, client_models, selected_clients, client_train_loaders)\n",
    "\n",
    "        server_test_val_acc = evaluate_model(global_model, test_loader)\n",
    "        if comm_round == COMM_ROUND - 1:\n",
    "            print(f\"\\n--- Server Evaluation ---\")\n",
    "            print(f\"Server Test Data Validation Accuracy: {server_test_val_acc:.2f}%\")\n",
    "\n",
    "        round_server_acc_list.append(server_test_val_acc)\n",
    "\n",
    "    return round_server_acc_list, round_client_acc_list, global_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-09-22T17:29:48.508509Z",
     "iopub.status.busy": "2025-09-22T17:29:48.507772Z",
     "iopub.status.idle": "2025-09-22T19:30:59.329407Z",
     "shell.execute_reply": "2025-09-22T19:30:59.328083Z",
     "shell.execute_reply.started": "2025-09-22T17:29:48.508486Z"
    },
    "trusted": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "================================================================================\n",
      "================================================================================\n",
      "\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 0.5\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.9007, Initial temp: 0.6374\n",
      "Client 1 - Het Score: 0.7295, Initial temp: 0.6944\n",
      "Client 2 - Het Score: 0.8682, Initial temp: 0.6478\n",
      "Client 3 - Het Score: 0.8552, Initial temp: 0.6521\n",
      "Client 4 - Het Score: 0.5698, Initial temp: 0.7521\n",
      "Client 5 - Het Score: 0.5956, Initial temp: 0.7424\n",
      "Client 6 - Het Score: 0.7451, Initial temp: 0.6890\n",
      "Client 7 - Het Score: 0.6556, Initial temp: 0.7205\n",
      "Client 8 - Het Score: 1.0000, Initial temp: 0.6065\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.6065\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.5192, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1574\n",
      "Epoch [2/5], Loss: 0.0290\n",
      "Epoch [3/5], Loss: 0.0168\n",
      "Epoch [4/5], Loss: 0.0139\n",
      "Epoch [5/5], Loss: 0.0129\n",
      "Client 0 Private Data Validation Accuracy: 59.69%\n",
      "\n",
      "--- Client 1 Local Training with T=0.5373, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.0616\n",
      "Epoch [2/5], Loss: 0.2988\n",
      "Epoch [3/5], Loss: 0.1345\n",
      "Epoch [4/5], Loss: 0.1150\n",
      "Epoch [5/5], Loss: 0.0654\n",
      "Client 1 Private Data Validation Accuracy: 71.88%\n",
      "\n",
      "--- Client 2 Local Training with T=0.5277, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4897\n",
      "Epoch [2/5], Loss: 0.1679\n",
      "Epoch [3/5], Loss: 0.0889\n",
      "Epoch [4/5], Loss: 0.0596\n",
      "Epoch [5/5], Loss: 0.0221\n",
      "Client 2 Private Data Validation Accuracy: 69.27%\n",
      "\n",
      "--- Client 3 Local Training with T=0.5591, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3783\n",
      "Epoch [2/5], Loss: 0.0400\n",
      "Epoch [3/5], Loss: 0.0238\n",
      "Epoch [4/5], Loss: 0.0195\n",
      "Epoch [5/5], Loss: 0.0097\n",
      "Client 3 Private Data Validation Accuracy: 90.10%\n",
      "Client 3 - decrease-temp to 0.531130\n",
      "\n",
      "--- Client 4 Local Training with T=0.6126, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1667\n",
      "Epoch [2/5], Loss: 0.0720\n",
      "Epoch [3/5], Loss: 0.0414\n",
      "Epoch [4/5], Loss: 0.0246\n",
      "Epoch [5/5], Loss: 0.0210\n",
      "Client 4 Private Data Validation Accuracy: 55.73%\n",
      "Client 4 - decrease-temp to 0.581944\n",
      "\n",
      "--- Client 5 Local Training with T=0.5745, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3892\n",
      "Epoch [2/5], Loss: 0.1545\n",
      "Epoch [3/5], Loss: 0.0774\n",
      "Epoch [4/5], Loss: 0.0583\n",
      "Epoch [5/5], Loss: 0.0455\n",
      "Client 5 Private Data Validation Accuracy: 57.81%\n",
      "\n",
      "--- Client 6 Local Training with T=0.5065, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.2823\n",
      "Epoch [2/5], Loss: 0.3382\n",
      "Epoch [3/5], Loss: 0.1400\n",
      "Epoch [4/5], Loss: 0.0723\n",
      "Epoch [5/5], Loss: 0.0405\n",
      "Client 6 Private Data Validation Accuracy: 60.94%\n",
      "\n",
      "--- Client 7 Local Training with T=0.5575, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1284\n",
      "Epoch [2/5], Loss: 0.0473\n",
      "Epoch [3/5], Loss: 0.0328\n",
      "Epoch [4/5], Loss: 0.0291\n",
      "Epoch [5/5], Loss: 0.0231\n",
      "Client 7 Private Data Validation Accuracy: 65.10%\n",
      "\n",
      "--- Client 8 Local Training with T=0.4459, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2232\n",
      "Epoch [2/5], Loss: 0.0274\n",
      "Epoch [3/5], Loss: 0.0262\n",
      "Epoch [4/5], Loss: 0.0140\n",
      "Epoch [5/5], Loss: 0.0183\n",
      "Client 8 Private Data Validation Accuracy: 80.08%\n",
      "\n",
      "--- Client 9 Local Training with T=0.4693, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4335\n",
      "Epoch [2/5], Loss: 0.1390\n",
      "Epoch [3/5], Loss: 0.0530\n",
      "Epoch [4/5], Loss: 0.0283\n",
      "Epoch [5/5], Loss: 0.0154\n",
      "Client 9 Private Data Validation Accuracy: 63.28%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 49.42%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 1.0\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.9007, Initial temp: 0.4063\n",
      "Client 1 - Het Score: 0.7295, Initial temp: 0.4821\n",
      "Client 2 - Het Score: 0.8682, Initial temp: 0.4197\n",
      "Client 3 - Het Score: 0.8552, Initial temp: 0.4252\n",
      "Client 4 - Het Score: 0.5698, Initial temp: 0.5656\n",
      "Client 5 - Het Score: 0.5956, Initial temp: 0.5512\n",
      "Client 6 - Het Score: 0.7451, Initial temp: 0.4747\n",
      "Client 7 - Het Score: 0.6556, Initial temp: 0.5191\n",
      "Client 8 - Het Score: 1.0000, Initial temp: 0.3679\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.3679\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.3483, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1419\n",
      "Epoch [2/5], Loss: 0.0214\n",
      "Epoch [3/5], Loss: 0.0161\n",
      "Epoch [4/5], Loss: 0.0100\n",
      "Epoch [5/5], Loss: 0.0109\n",
      "Client 0 Private Data Validation Accuracy: 58.12%\n",
      "Client 0 - decrease-temp to 0.330922\n",
      "\n",
      "--- Client 1 Local Training with T=0.3927, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.8009\n",
      "Epoch [2/5], Loss: 0.2289\n",
      "Epoch [3/5], Loss: 0.1642\n",
      "Epoch [4/5], Loss: 0.0799\n",
      "Epoch [5/5], Loss: 0.0479\n",
      "Client 1 Private Data Validation Accuracy: 70.31%\n",
      "\n",
      "--- Client 2 Local Training with T=0.3419, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3928\n",
      "Epoch [2/5], Loss: 0.0605\n",
      "Epoch [3/5], Loss: 0.0219\n",
      "Epoch [4/5], Loss: 0.0266\n",
      "Epoch [5/5], Loss: 0.0162\n",
      "Client 2 Private Data Validation Accuracy: 70.83%\n",
      "\n",
      "--- Client 3 Local Training with T=0.3126, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3103\n",
      "Epoch [2/5], Loss: 0.0248\n",
      "Epoch [3/5], Loss: 0.0134\n",
      "Epoch [4/5], Loss: 0.0075\n",
      "Epoch [5/5], Loss: 0.0031\n",
      "Client 3 Private Data Validation Accuracy: 89.58%\n",
      "\n",
      "--- Client 4 Local Training with T=0.5373, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1813\n",
      "Epoch [2/5], Loss: 0.0633\n",
      "Epoch [3/5], Loss: 0.0412\n",
      "Epoch [4/5], Loss: 0.0249\n",
      "Epoch [5/5], Loss: 0.0228\n",
      "Client 4 Private Data Validation Accuracy: 59.90%\n",
      "Client 4 - decrease-temp to 0.510473\n",
      "\n",
      "--- Client 5 Local Training with T=0.4052, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4354\n",
      "Epoch [2/5], Loss: 0.1589\n",
      "Epoch [3/5], Loss: 0.0694\n",
      "Epoch [4/5], Loss: 0.0454\n",
      "Epoch [5/5], Loss: 0.0359\n",
      "Client 5 Private Data Validation Accuracy: 53.12%\n",
      "Client 5 - decrease-temp to 0.384937\n",
      "\n",
      "--- Client 6 Local Training with T=0.3673, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.1713\n",
      "Epoch [2/5], Loss: 0.2912\n",
      "Epoch [3/5], Loss: 0.1365\n",
      "Epoch [4/5], Loss: 0.0628\n",
      "Epoch [5/5], Loss: 0.0230\n",
      "Client 6 Private Data Validation Accuracy: 67.19%\n",
      "\n",
      "--- Client 7 Local Training with T=0.4451, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1273\n",
      "Epoch [2/5], Loss: 0.0473\n",
      "Epoch [3/5], Loss: 0.0328\n",
      "Epoch [4/5], Loss: 0.0243\n",
      "Epoch [5/5], Loss: 0.0207\n",
      "Client 7 Private Data Validation Accuracy: 65.62%\n",
      "\n",
      "--- Client 8 Local Training with T=0.2847, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1584\n",
      "Epoch [2/5], Loss: 0.0112\n",
      "Epoch [3/5], Loss: 0.0160\n",
      "Epoch [4/5], Loss: 0.0165\n",
      "Epoch [5/5], Loss: 0.0481\n",
      "Client 8 Private Data Validation Accuracy: 76.95%\n",
      "Client 8 - decrease-temp to 0.270425\n",
      "\n",
      "--- Client 9 Local Training with T=0.2996, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3693\n",
      "Epoch [2/5], Loss: 0.0991\n",
      "Epoch [3/5], Loss: 0.0505\n",
      "Epoch [4/5], Loss: 0.0291\n",
      "Epoch [5/5], Loss: 0.0131\n",
      "Client 9 Private Data Validation Accuracy: 64.84%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 49.81%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 1.5\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.9007, Initial temp: 0.2590\n",
      "Client 1 - Het Score: 0.7295, Initial temp: 0.3348\n",
      "Client 2 - Het Score: 0.8682, Initial temp: 0.2719\n",
      "Client 3 - Het Score: 0.8552, Initial temp: 0.2773\n",
      "Client 4 - Het Score: 0.5698, Initial temp: 0.4254\n",
      "Client 5 - Het Score: 0.5956, Initial temp: 0.4092\n",
      "Client 6 - Het Score: 0.7451, Initial temp: 0.3270\n",
      "Client 7 - Het Score: 0.6556, Initial temp: 0.3740\n",
      "Client 8 - Het Score: 1.0000, Initial temp: 0.2231\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.2231\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.2220, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1001\n",
      "Epoch [2/5], Loss: 0.0234\n",
      "Epoch [3/5], Loss: 0.0175\n",
      "Epoch [4/5], Loss: 0.0223\n",
      "Epoch [5/5], Loss: 0.0321\n",
      "Client 0 Private Data Validation Accuracy: 61.25%\n",
      "\n",
      "--- Client 1 Local Training with T=0.2461, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.8814\n",
      "Epoch [2/5], Loss: 0.1977\n",
      "Epoch [3/5], Loss: 0.1026\n",
      "Epoch [4/5], Loss: 0.0528\n",
      "Epoch [5/5], Loss: 0.0323\n",
      "Client 1 Private Data Validation Accuracy: 73.44%\n",
      "\n",
      "--- Client 2 Local Training with T=0.2215, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3655\n",
      "Epoch [2/5], Loss: 0.1110\n",
      "Epoch [3/5], Loss: 0.1327\n",
      "Epoch [4/5], Loss: 0.0513\n",
      "Epoch [5/5], Loss: 0.0419\n",
      "Client 2 Private Data Validation Accuracy: 69.79%\n",
      "\n",
      "--- Client 3 Local Training with T=0.2377, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2575\n",
      "Epoch [2/5], Loss: 0.0121\n",
      "Epoch [3/5], Loss: 0.0077\n",
      "Epoch [4/5], Loss: 0.0078\n",
      "Epoch [5/5], Loss: 0.0024\n",
      "Client 3 Private Data Validation Accuracy: 90.10%\n",
      "\n",
      "--- Client 4 Local Training with T=0.3465, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1256\n",
      "Epoch [2/5], Loss: 0.0532\n",
      "Epoch [3/5], Loss: 0.0334\n",
      "Epoch [4/5], Loss: 0.0179\n",
      "Epoch [5/5], Loss: 0.0165\n",
      "Client 4 Private Data Validation Accuracy: 59.90%\n",
      "Client 4 - decrease-temp to 0.329160\n",
      "\n",
      "--- Client 5 Local Training with T=0.3167, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3507\n",
      "Epoch [2/5], Loss: 0.1322\n",
      "Epoch [3/5], Loss: 0.0585\n",
      "Epoch [4/5], Loss: 0.0403\n",
      "Epoch [5/5], Loss: 0.0286\n",
      "Client 5 Private Data Validation Accuracy: 56.25%\n",
      "\n",
      "--- Client 6 Local Training with T=0.2664, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.8469\n",
      "Epoch [2/5], Loss: 0.1257\n",
      "Epoch [3/5], Loss: 0.0840\n",
      "Epoch [4/5], Loss: 0.0321\n",
      "Epoch [5/5], Loss: 0.0256\n",
      "Client 6 Private Data Validation Accuracy: 62.50%\n",
      "\n",
      "--- Client 7 Local Training with T=0.3376, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1173\n",
      "Epoch [2/5], Loss: 0.0371\n",
      "Epoch [3/5], Loss: 0.0264\n",
      "Epoch [4/5], Loss: 0.0219\n",
      "Epoch [5/5], Loss: 0.0245\n",
      "Client 7 Private Data Validation Accuracy: 66.67%\n",
      "\n",
      "--- Client 8 Local Training with T=0.1727, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1496\n",
      "Epoch [2/5], Loss: 0.0473\n",
      "Epoch [3/5], Loss: 0.0107\n",
      "Epoch [4/5], Loss: 0.0265\n",
      "Epoch [5/5], Loss: 0.0949\n",
      "Client 8 Private Data Validation Accuracy: 76.95%\n",
      "\n",
      "--- Client 9 Local Training with T=0.1558, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2617\n",
      "Epoch [2/5], Loss: 0.0685\n",
      "Epoch [3/5], Loss: 0.0503\n",
      "Epoch [4/5], Loss: 0.0222\n",
      "Epoch [5/5], Loss: 0.0386\n",
      "Client 9 Private Data Validation Accuracy: 60.16%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 51.06%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 2.0\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.9007, Initial temp: 0.1651\n",
      "Client 1 - Het Score: 0.7295, Initial temp: 0.2325\n",
      "Client 2 - Het Score: 0.8682, Initial temp: 0.1762\n",
      "Client 3 - Het Score: 0.8552, Initial temp: 0.1808\n",
      "Client 4 - Het Score: 0.5698, Initial temp: 0.3199\n",
      "Client 5 - Het Score: 0.5956, Initial temp: 0.3038\n",
      "Client 6 - Het Score: 0.7451, Initial temp: 0.2253\n",
      "Client 7 - Het Score: 0.6556, Initial temp: 0.2695\n",
      "Client 8 - Het Score: 1.0000, Initial temp: 0.1353\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.1353\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.1415, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1223\n",
      "Epoch [2/5], Loss: 0.0293\n",
      "Epoch [3/5], Loss: 0.0281\n",
      "Epoch [4/5], Loss: 0.0479\n",
      "Epoch [5/5], Loss: 0.0375\n",
      "Client 0 Private Data Validation Accuracy: 61.88%\n",
      "\n",
      "--- Client 1 Local Training with T=0.2098, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.8251\n",
      "Epoch [2/5], Loss: 0.1949\n",
      "Epoch [3/5], Loss: 0.0613\n",
      "Epoch [4/5], Loss: 0.0427\n",
      "Epoch [5/5], Loss: 0.0390\n",
      "Client 1 Private Data Validation Accuracy: 73.44%\n",
      "\n",
      "--- Client 2 Local Training with T=0.1510, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3852\n",
      "Epoch [2/5], Loss: 0.0411\n",
      "Epoch [3/5], Loss: 0.0155\n",
      "Epoch [4/5], Loss: 0.0545\n",
      "Epoch [5/5], Loss: 0.0661\n",
      "Client 2 Private Data Validation Accuracy: 65.10%\n",
      "\n",
      "--- Client 3 Local Training with T=0.1263, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3443\n",
      "Epoch [2/5], Loss: 0.0092\n",
      "Epoch [3/5], Loss: 0.0310\n",
      "Epoch [4/5], Loss: 0.0397\n",
      "Epoch [5/5], Loss: 0.0015\n",
      "Client 3 Private Data Validation Accuracy: 91.15%\n",
      "\n",
      "--- Client 4 Local Training with T=0.2743, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1419\n",
      "Epoch [2/5], Loss: 0.0381\n",
      "Epoch [3/5], Loss: 0.0224\n",
      "Epoch [4/5], Loss: 0.0175\n",
      "Epoch [5/5], Loss: 0.0175\n",
      "Client 4 Private Data Validation Accuracy: 63.02%\n",
      "\n",
      "--- Client 5 Local Training with T=0.2475, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3479\n",
      "Epoch [2/5], Loss: 0.1394\n",
      "Epoch [3/5], Loss: 0.0727\n",
      "Epoch [4/5], Loss: 0.0410\n",
      "Epoch [5/5], Loss: 0.0316\n",
      "Client 5 Private Data Validation Accuracy: 63.28%\n",
      "\n",
      "--- Client 6 Local Training with T=0.1743, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.8450\n",
      "Epoch [2/5], Loss: 0.1500\n",
      "Epoch [3/5], Loss: 0.0783\n",
      "Epoch [4/5], Loss: 0.0621\n",
      "Epoch [5/5], Loss: 0.0204\n",
      "Client 6 Private Data Validation Accuracy: 62.50%\n",
      "Client 6 - decrease-temp to 0.165629\n",
      "\n",
      "--- Client 7 Local Training with T=0.2195, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1279\n",
      "Epoch [2/5], Loss: 0.0439\n",
      "Epoch [3/5], Loss: 0.0255\n",
      "Epoch [4/5], Loss: 0.0322\n",
      "Epoch [5/5], Loss: 0.0185\n",
      "Client 7 Private Data Validation Accuracy: 66.67%\n",
      "Client 7 - decrease-temp to 0.208532\n",
      "\n",
      "--- Client 8 Local Training with T=0.1047, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.0992\n",
      "Epoch [2/5], Loss: 0.0298\n",
      "Epoch [3/5], Loss: 0.0164\n",
      "Epoch [4/5], Loss: 0.0218\n",
      "Epoch [5/5], Loss: 0.0260\n",
      "Client 8 Private Data Validation Accuracy: 77.73%\n",
      "\n",
      "--- Client 9 Local Training with T=0.1102, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3456\n",
      "Epoch [2/5], Loss: 0.0980\n",
      "Epoch [3/5], Loss: 0.1452\n",
      "Epoch [4/5], Loss: 0.0733\n",
      "Epoch [5/5], Loss: 0.0308\n",
      "Client 9 Private Data Validation Accuracy: 62.50%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 51.68%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 2.5\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.9007, Initial temp: 0.1052\n",
      "Client 1 - Het Score: 0.7295, Initial temp: 0.1614\n",
      "Client 2 - Het Score: 0.8682, Initial temp: 0.1141\n",
      "Client 3 - Het Score: 0.8552, Initial temp: 0.1179\n",
      "Client 4 - Het Score: 0.5698, Initial temp: 0.2406\n",
      "Client 5 - Het Score: 0.5956, Initial temp: 0.2256\n",
      "Client 6 - Het Score: 0.7451, Initial temp: 0.1552\n",
      "Client 7 - Het Score: 0.6556, Initial temp: 0.1942\n",
      "Client 8 - Het Score: 1.0000, Initial temp: 0.0821\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.0821\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.0902, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1288\n",
      "Epoch [2/5], Loss: 0.0373\n",
      "Epoch [3/5], Loss: 0.0812\n",
      "Epoch [4/5], Loss: 0.0685\n",
      "Epoch [5/5], Loss: 0.0580\n",
      "Client 0 Private Data Validation Accuracy: 50.94%\n",
      "Client 0 - decrease-temp to 0.085699\n",
      "\n",
      "--- Client 1 Local Training with T=0.1315, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 2.2020\n",
      "Epoch [2/5], Loss: 0.3799\n",
      "Epoch [3/5], Loss: 0.1516\n",
      "Epoch [4/5], Loss: 0.0858\n",
      "Epoch [5/5], Loss: 0.0598\n",
      "Client 1 Private Data Validation Accuracy: 64.06%\n",
      "\n",
      "--- Client 2 Local Training with T=0.0930, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.5701\n",
      "Epoch [2/5], Loss: 0.4933\n",
      "Epoch [3/5], Loss: 0.2113\n",
      "Epoch [4/5], Loss: 0.1472\n",
      "Epoch [5/5], Loss: 0.1267\n",
      "Client 2 Private Data Validation Accuracy: 65.62%\n",
      "Client 2 - decrease-temp to 0.088306\n",
      "\n",
      "--- Client 3 Local Training with T=0.0960, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4173\n",
      "Epoch [2/5], Loss: 0.0327\n",
      "Epoch [3/5], Loss: 0.0438\n",
      "Epoch [4/5], Loss: 0.0176\n",
      "Epoch [5/5], Loss: 0.0292\n",
      "Client 3 Private Data Validation Accuracy: 89.58%\n",
      "Client 3 - decrease-temp to 0.091233\n",
      "\n",
      "--- Client 4 Local Training with T=0.2063, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1708\n",
      "Epoch [2/5], Loss: 0.0766\n",
      "Epoch [3/5], Loss: 0.0492\n",
      "Epoch [4/5], Loss: 0.0203\n",
      "Epoch [5/5], Loss: 0.0221\n",
      "Client 4 Private Data Validation Accuracy: 57.81%\n",
      "\n",
      "--- Client 5 Local Training with T=0.1746, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4339\n",
      "Epoch [2/5], Loss: 0.2364\n",
      "Epoch [3/5], Loss: 0.1182\n",
      "Epoch [4/5], Loss: 0.0349\n",
      "Epoch [5/5], Loss: 0.0396\n",
      "Client 5 Private Data Validation Accuracy: 61.72%\n",
      "\n",
      "--- Client 6 Local Training with T=0.1264, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.4023\n",
      "Epoch [2/5], Loss: 0.2543\n",
      "Epoch [3/5], Loss: 0.1685\n",
      "Epoch [4/5], Loss: 0.0933\n",
      "Epoch [5/5], Loss: 0.0292\n",
      "Client 6 Private Data Validation Accuracy: 70.31%\n",
      "\n",
      "--- Client 7 Local Training with T=0.1845, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2090\n",
      "Epoch [2/5], Loss: 0.0755\n",
      "Epoch [3/5], Loss: 0.0411\n",
      "Epoch [4/5], Loss: 0.0238\n",
      "Epoch [5/5], Loss: 0.0203\n",
      "Client 7 Private Data Validation Accuracy: 70.83%\n",
      "\n",
      "--- Client 8 Local Training with T=0.0635, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.5449\n",
      "Epoch [2/5], Loss: 0.3263\n",
      "Epoch [3/5], Loss: 0.2062\n",
      "Epoch [4/5], Loss: 0.1376\n",
      "Epoch [5/5], Loss: 0.1110\n",
      "Client 8 Private Data Validation Accuracy: 78.12%\n",
      "\n",
      "--- Client 9 Local Training with T=0.0704, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5738\n",
      "Epoch [2/5], Loss: 0.2275\n",
      "Epoch [3/5], Loss: 0.0938\n",
      "Epoch [4/5], Loss: 0.1917\n",
      "Epoch [5/5], Loss: 0.1564\n",
      "Client 9 Private Data Validation Accuracy: 53.12%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 48.62%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 3.0\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.9007, Initial temp: 0.0671\n",
      "Client 1 - Het Score: 0.7295, Initial temp: 0.1121\n",
      "Client 2 - Het Score: 0.8682, Initial temp: 0.0739\n",
      "Client 3 - Het Score: 0.8552, Initial temp: 0.0769\n",
      "Client 4 - Het Score: 0.5698, Initial temp: 0.1810\n",
      "Client 5 - Het Score: 0.5956, Initial temp: 0.1675\n",
      "Client 6 - Het Score: 0.7451, Initial temp: 0.1070\n",
      "Client 7 - Het Score: 0.6556, Initial temp: 0.1399\n",
      "Client 8 - Het Score: 1.0000, Initial temp: 0.0500\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.0500\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.0605, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2173\n",
      "Epoch [2/5], Loss: 0.0827\n",
      "Epoch [3/5], Loss: 0.0632\n",
      "Epoch [4/5], Loss: 0.0745\n",
      "Epoch [5/5], Loss: 0.1148\n",
      "Client 0 Private Data Validation Accuracy: 61.25%\n",
      "\n",
      "--- Client 1 Local Training with T=0.0961, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.1691\n",
      "Epoch [2/5], Loss: 0.3396\n",
      "Epoch [3/5], Loss: 0.1622\n",
      "Epoch [4/5], Loss: 0.1590\n",
      "Epoch [5/5], Loss: 0.0901\n",
      "Client 1 Private Data Validation Accuracy: 65.62%\n",
      "\n",
      "--- Client 2 Local Training with T=0.0667, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.3833\n",
      "Epoch [2/5], Loss: 0.2918\n",
      "Epoch [3/5], Loss: 0.1962\n",
      "Epoch [4/5], Loss: 0.1603\n",
      "Epoch [5/5], Loss: 0.2360\n",
      "Client 2 Private Data Validation Accuracy: 69.79%\n",
      "\n",
      "--- Client 3 Local Training with T=0.0694, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5214\n",
      "Epoch [2/5], Loss: 0.0650\n",
      "Epoch [3/5], Loss: 0.0423\n",
      "Epoch [4/5], Loss: 0.0407\n",
      "Epoch [5/5], Loss: 0.0439\n",
      "Client 3 Private Data Validation Accuracy: 91.67%\n",
      "\n",
      "--- Client 4 Local Training with T=0.1474, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2149\n",
      "Epoch [2/5], Loss: 0.1023\n",
      "Epoch [3/5], Loss: 0.0581\n",
      "Epoch [4/5], Loss: 0.0327\n",
      "Epoch [5/5], Loss: 0.0576\n",
      "Client 4 Private Data Validation Accuracy: 56.25%\n",
      "Client 4 - decrease-temp to 0.140022\n",
      "\n",
      "--- Client 5 Local Training with T=0.1591, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5520\n",
      "Epoch [2/5], Loss: 0.4103\n",
      "Epoch [3/5], Loss: 0.2405\n",
      "Epoch [4/5], Loss: 0.1803\n",
      "Epoch [5/5], Loss: 0.1225\n",
      "Client 5 Private Data Validation Accuracy: 47.66%\n",
      "\n",
      "--- Client 6 Local Training with T=0.0871, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.2680\n",
      "Epoch [2/5], Loss: 0.2765\n",
      "Epoch [3/5], Loss: 0.2505\n",
      "Epoch [4/5], Loss: 0.2162\n",
      "Epoch [5/5], Loss: 0.1053\n",
      "Client 6 Private Data Validation Accuracy: 59.38%\n",
      "\n",
      "--- Client 7 Local Training with T=0.1263, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2704\n",
      "Epoch [2/5], Loss: 0.0770\n",
      "Epoch [3/5], Loss: 0.0727\n",
      "Epoch [4/5], Loss: 0.0804\n",
      "Epoch [5/5], Loss: 0.0797\n",
      "Client 7 Private Data Validation Accuracy: 59.38%\n",
      "Client 7 - decrease-temp to 0.119951\n",
      "\n",
      "--- Client 8 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.3218\n",
      "Epoch [2/5], Loss: 0.1083\n",
      "Epoch [3/5], Loss: 0.0903\n",
      "Epoch [4/5], Loss: 0.0479\n",
      "Epoch [5/5], Loss: 0.0851\n",
      "Client 8 Private Data Validation Accuracy: 76.95%\n",
      "\n",
      "--- Client 9 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.5622\n",
      "Epoch [2/5], Loss: 0.5396\n",
      "Epoch [3/5], Loss: 0.2788\n",
      "Epoch [4/5], Loss: 0.2678\n",
      "Epoch [5/5], Loss: 0.1752\n",
      "Client 9 Private Data Validation Accuracy: 62.50%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 49.39%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 3.5\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.9007, Initial temp: 0.0500\n",
      "Client 1 - Het Score: 0.7295, Initial temp: 0.0778\n",
      "Client 2 - Het Score: 0.8682, Initial temp: 0.0500\n",
      "Client 3 - Het Score: 0.8552, Initial temp: 0.0501\n",
      "Client 4 - Het Score: 0.5698, Initial temp: 0.1361\n",
      "Client 5 - Het Score: 0.5956, Initial temp: 0.1243\n",
      "Client 6 - Het Score: 0.7451, Initial temp: 0.0737\n",
      "Client 7 - Het Score: 0.6556, Initial temp: 0.1008\n",
      "Client 8 - Het Score: 1.0000, Initial temp: 0.0500\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.0500\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1791\n",
      "Epoch [2/5], Loss: 0.0946\n",
      "Epoch [3/5], Loss: 0.1267\n",
      "Epoch [4/5], Loss: 0.0986\n",
      "Epoch [5/5], Loss: 0.0931\n",
      "Client 0 Private Data Validation Accuracy: 57.50%\n",
      "\n",
      "--- Client 1 Local Training with T=0.0634, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.4835\n",
      "Epoch [2/5], Loss: 0.5102\n",
      "Epoch [3/5], Loss: 0.2539\n",
      "Epoch [4/5], Loss: 0.1646\n",
      "Epoch [5/5], Loss: 0.1365\n",
      "Client 1 Private Data Validation Accuracy: 59.38%\n",
      "\n",
      "--- Client 2 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.9235\n",
      "Epoch [2/5], Loss: 0.5698\n",
      "Epoch [3/5], Loss: 0.3166\n",
      "Epoch [4/5], Loss: 0.1838\n",
      "Epoch [5/5], Loss: 0.1494\n",
      "Client 2 Private Data Validation Accuracy: 69.79%\n",
      "\n",
      "--- Client 3 Local Training with T=0.0501, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5632\n",
      "Epoch [2/5], Loss: 0.2526\n",
      "Epoch [3/5], Loss: 0.0746\n",
      "Epoch [4/5], Loss: 0.0650\n",
      "Epoch [5/5], Loss: 0.0786\n",
      "Client 3 Private Data Validation Accuracy: 92.19%\n",
      "\n",
      "--- Client 4 Local Training with T=0.1167, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2499\n",
      "Epoch [2/5], Loss: 0.1202\n",
      "Epoch [3/5], Loss: 0.1213\n",
      "Epoch [4/5], Loss: 0.0933\n",
      "Epoch [5/5], Loss: 0.0410\n",
      "Client 4 Private Data Validation Accuracy: 56.25%\n",
      "Client 4 - decrease-temp to 0.110850\n",
      "\n",
      "--- Client 5 Local Training with T=0.1013, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.5899\n",
      "Epoch [2/5], Loss: 0.5923\n",
      "Epoch [3/5], Loss: 0.3810\n",
      "Epoch [4/5], Loss: 0.1504\n",
      "Epoch [5/5], Loss: 0.1012\n",
      "Client 5 Private Data Validation Accuracy: 57.81%\n",
      "\n",
      "--- Client 6 Local Training with T=0.0570, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.1734\n",
      "Epoch [2/5], Loss: 0.3292\n",
      "Epoch [3/5], Loss: 0.3586\n",
      "Epoch [4/5], Loss: 0.1552\n",
      "Epoch [5/5], Loss: 0.1541\n",
      "Client 6 Private Data Validation Accuracy: 65.62%\n",
      "\n",
      "--- Client 7 Local Training with T=0.1008, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2427\n",
      "Epoch [2/5], Loss: 0.0794\n",
      "Epoch [3/5], Loss: 0.0489\n",
      "Epoch [4/5], Loss: 0.0720\n",
      "Epoch [5/5], Loss: 0.0907\n",
      "Client 7 Private Data Validation Accuracy: 61.46%\n",
      "\n",
      "--- Client 8 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2540\n",
      "Epoch [2/5], Loss: 0.0726\n",
      "Epoch [3/5], Loss: 0.1089\n",
      "Epoch [4/5], Loss: 0.1346\n",
      "Epoch [5/5], Loss: 0.2201\n",
      "Client 8 Private Data Validation Accuracy: 73.05%\n",
      "\n",
      "--- Client 9 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.1099\n",
      "Epoch [2/5], Loss: 0.2896\n",
      "Epoch [3/5], Loss: 0.2721\n",
      "Epoch [4/5], Loss: 0.1762\n",
      "Epoch [5/5], Loss: 0.1159\n",
      "Client 9 Private Data Validation Accuracy: 61.72%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 47.60%\n",
      "\n",
      "\n",
      "##################################################\n",
      "# RUNNING FedChill with Scaling Factor 4.0\n",
      "##################################################\n",
      "\n",
      "Data partitioning complete.\n",
      "Client 0 - Het Score: 0.9007, Initial temp: 0.0500\n",
      "Client 1 - Het Score: 0.7295, Initial temp: 0.0540\n",
      "Client 2 - Het Score: 0.8682, Initial temp: 0.0500\n",
      "Client 3 - Het Score: 0.8552, Initial temp: 0.0500\n",
      "Client 4 - Het Score: 0.5698, Initial temp: 0.1024\n",
      "Client 5 - Het Score: 0.5956, Initial temp: 0.0923\n",
      "Client 6 - Het Score: 0.7451, Initial temp: 0.0508\n",
      "Client 7 - Het Score: 0.6556, Initial temp: 0.0726\n",
      "Client 8 - Het Score: 1.0000, Initial temp: 0.0500\n",
      "Client 9 - Het Score: 1.0000, Initial temp: 0.0500\n",
      "\n",
      "==================== COMMUNICATION ROUND 1 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 2 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 3 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 4 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 5 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 6 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 7 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 8 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 9 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 10 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 11 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 12 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 13 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 14 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 15 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 16 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 17 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 18 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 19 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 20 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 21 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 22 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 23 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 24 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 25 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 26 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 27 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 28 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 29 ====================\n",
      "\n",
      "==================== COMMUNICATION ROUND 30 ====================\n",
      "\n",
      "--- LOCAL TRAINING OF CLIENTS ---\n",
      "\n",
      "--- Client 0 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.1830\n",
      "Epoch [2/5], Loss: 0.0988\n",
      "Epoch [3/5], Loss: 0.1408\n",
      "Epoch [4/5], Loss: 0.1374\n",
      "Epoch [5/5], Loss: 0.1370\n",
      "Client 0 Private Data Validation Accuracy: 59.69%\n",
      "\n",
      "--- Client 1 Local Training with T=0.0540, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.9192\n",
      "Epoch [2/5], Loss: 0.7738\n",
      "Epoch [3/5], Loss: 0.5350\n",
      "Epoch [4/5], Loss: 0.2131\n",
      "Epoch [5/5], Loss: 0.1268\n",
      "Client 1 Private Data Validation Accuracy: 64.06%\n",
      "\n",
      "--- Client 2 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.4807\n",
      "Epoch [2/5], Loss: 0.3468\n",
      "Epoch [3/5], Loss: 0.2176\n",
      "Epoch [4/5], Loss: 0.2268\n",
      "Epoch [5/5], Loss: 0.1491\n",
      "Client 2 Private Data Validation Accuracy: 66.67%\n",
      "\n",
      "--- Client 3 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2150\n",
      "Epoch [2/5], Loss: 0.0470\n",
      "Epoch [3/5], Loss: 0.0504\n",
      "Epoch [4/5], Loss: 0.0164\n",
      "Epoch [5/5], Loss: 0.0453\n",
      "Client 3 Private Data Validation Accuracy: 86.98%\n",
      "\n",
      "--- Client 4 Local Training with T=0.0878, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2636\n",
      "Epoch [2/5], Loss: 0.0973\n",
      "Epoch [3/5], Loss: 0.0843\n",
      "Epoch [4/5], Loss: 0.0623\n",
      "Epoch [5/5], Loss: 0.1342\n",
      "Client 4 Private Data Validation Accuracy: 60.94%\n",
      "\n",
      "--- Client 5 Local Training with T=0.0752, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.7960\n",
      "Epoch [2/5], Loss: 0.4272\n",
      "Epoch [3/5], Loss: 0.1898\n",
      "Epoch [4/5], Loss: 0.1525\n",
      "Epoch [5/5], Loss: 0.1382\n",
      "Client 5 Private Data Validation Accuracy: 53.91%\n",
      "\n",
      "--- Client 6 Local Training with T=0.0508, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 1.3762\n",
      "Epoch [2/5], Loss: 0.4541\n",
      "Epoch [3/5], Loss: 0.2155\n",
      "Epoch [4/5], Loss: 0.1404\n",
      "Epoch [5/5], Loss: 0.0898\n",
      "Client 6 Private Data Validation Accuracy: 68.75%\n",
      "\n",
      "--- Client 7 Local Training with T=0.0623, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2170\n",
      "Epoch [2/5], Loss: 0.1496\n",
      "Epoch [3/5], Loss: 0.0769\n",
      "Epoch [4/5], Loss: 0.0620\n",
      "Epoch [5/5], Loss: 0.1211\n",
      "Client 7 Private Data Validation Accuracy: 55.21%\n",
      "Client 7 - decrease-temp to 0.059157\n",
      "\n",
      "--- Client 8 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.2409\n",
      "Epoch [2/5], Loss: 0.0623\n",
      "Epoch [3/5], Loss: 0.0841\n",
      "Epoch [4/5], Loss: 0.0470\n",
      "Epoch [5/5], Loss: 0.1889\n",
      "Client 8 Private Data Validation Accuracy: 75.39%\n",
      "\n",
      "--- Client 9 Local Training with T=0.0500, LR=0.010000 ---\n",
      "Epoch [1/5], Loss: 0.4801\n",
      "Epoch [2/5], Loss: 0.1715\n",
      "Epoch [3/5], Loss: 0.0953\n",
      "Epoch [4/5], Loss: 0.0997\n",
      "Epoch [5/5], Loss: 0.1910\n",
      "Client 9 Private Data Validation Accuracy: 59.38%\n",
      "\n",
      "--- Server Evaluation ---\n",
      "Server Test Data Validation Accuracy: 49.26%\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'model_name' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_36/2023819718.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrun_experiment\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mplot_model_comparison\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/tmp/ipykernel_36/3667192117.py\u001b[0m in \u001b[0;36mrun_experiment\u001b[0;34m()\u001b[0m\n\u001b[1;32m     23\u001b[0m         }\n\u001b[1;32m     24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m     \u001b[0mplot_model_comparison\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_results\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mmodel_results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'model_name' is not defined"
     ]
    }
   ],
   "source": [
    "results = run_experiment()\n",
    "plot_model_comparison(results)\n",
    "\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nYAg_h1n7NAt"
   },
   "source": [
    "---"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kaggle": {
   "accelerator": "nvidiaTeslaT4",
   "dataSources": [],
   "dockerImageVersionId": 31089,
   "isGpuEnabled": true,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
