{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3f0f7008-7fd8-4563-809d-bb82777ef038",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import math\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import wandb\n",
    "from torch import nn\n",
    "\n",
    "from ddro_utils_20230508 import DatasetSplit, average_weights, cifar_iid\n",
    "from myDataLoader import get_train_val_test_loader\n",
    "from myutils import save_checkpoint_epoch, ResultsLog\n",
    "\n",
    "def DS_FedDRO(args, model, train_data, dict_users, num_epochs, I, eta_x, eta_y, beta, lmbda, device='cuda'):\n",
    "    \"\"\"\n",
    "    DS-FedDRO: Federated Non-Convex CO algorithm with 2-Sided Learning Rate (Algorithm 3).\n",
    "\n",
    "    Every round, each client resumes from its own local state (or from the global\n",
    "    model right after an aggregation), updates its embedding with `update_ykt_batch`,\n",
    "    runs one local pass of `update_x_k`, and records the training metrics of its\n",
    "    locally updated model. Client weights and embeddings are averaged every `I`\n",
    "    rounds and on the final round, so local progress made between aggregations\n",
    "    is not discarded.\n",
    "\n",
    "    Args:\n",
    "        args: Command-line arguments for training configuration. Reads num_users,\n",
    "            local_bs, works, results_dir, res_filename and, if present, wandb_entity.\n",
    "        model: PyTorch model architecture (e.g., ResNet); deep-copied, not modified.\n",
    "        train_data: The full training dataset.\n",
    "        dict_users: The dictionary mapping each client to their data indices.\n",
    "        num_epochs: Number of communication rounds (global epochs) to run.\n",
    "        I: Communication interval, in rounds, for global aggregation.\n",
    "        eta_x: Learning rate for model parameters x.\n",
    "        eta_y: Learning rate for compositional embeddings y (forwarded to update_x_k).\n",
    "        beta: Momentum parameter for embedding updates.\n",
    "        lmbda: Regularization parameter for the DRO loss.\n",
    "        device: 'cuda' or 'cpu'.\n",
    "\n",
    "    Returns:\n",
    "        global_model: The trained global model after federated learning.\n",
    "        y_k: Final embeddings for each client.\n",
    "    \"\"\"\n",
    "    # Global model, per-client embeddings, and per-client model state carried\n",
    "    # between aggregations (None = start from the current global model).\n",
    "    global_model = copy.deepcopy(model).to(device)\n",
    "    y_k = [0.0 for _ in range(args.num_users)]\n",
    "    client_states = [None] * args.num_users\n",
    "\n",
    "    # WandB logging; the entity comes from args.wandb_entity when provided,\n",
    "    # otherwise None lets wandb use the logged-in user's default entity.\n",
    "    wandb.init(config=args, project=\"DS_FedDRO_Project\", entity=getattr(args, 'wandb_entity', None))\n",
    "\n",
    "    # CSV results logging\n",
    "    results_file = os.path.join(args.results_dir, args.res_filename + '_results.csv')\n",
    "    results = ResultsLog(results_file)\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        print(f'\\n | Global Training Round : {epoch+1} |\\n')\n",
    "\n",
    "        global_model.train()\n",
    "\n",
    "        train_loss, train_acc = 0, 0\n",
    "        for idx in range(args.num_users):\n",
    "            local_model = copy.deepcopy(global_model)\n",
    "            if client_states[idx] is not None:\n",
    "                # No aggregation since this client's last round: continue its local trajectory.\n",
    "                local_model.load_state_dict(client_states[idx])\n",
    "            local_model.train()\n",
    "\n",
    "            # Load data for the current client\n",
    "            train_data_idx = DatasetSplit(train_data, dict_users[idx])\n",
    "            train_loader_idx = torch.utils.data.DataLoader(train_data_idx,\n",
    "                                                           batch_size=args.local_bs,\n",
    "                                                           shuffle=True,\n",
    "                                                           num_workers=args.works,\n",
    "                                                           pin_memory=True)\n",
    "\n",
    "            # Momentum-based update of embedding y_k.\n",
    "            # TODO(review): update_ykt_batch is neither defined nor imported in this notebook,\n",
    "            # and model_pre/model_cur receive the same model -- confirm both against Algorithm 3.\n",
    "            ykt = update_ykt_batch(model_pre=local_model, model_cur=local_model, global_round=epoch,\n",
    "                                   ykt=y_k[idx], trainloader=train_loader_idx, beta=beta, lmbda=lmbda)\n",
    "            y_k[idx] = ykt  # carried into the next round even when no aggregation happens\n",
    "\n",
    "            # Local update of model parameters using two-sided learning rates\n",
    "            local_weights_client = update_x_k(local_model, ykt, train_loader_idx, eta_x, eta_y, lmbda, device=device)\n",
    "            client_states[idx] = copy.deepcopy(local_weights_client)\n",
    "\n",
    "            # Training loss/accuracy of the client's locally updated model\n",
    "            client_loss, client_acc = calculate_metrics(local_model, train_loader_idx, device)\n",
    "            train_loss += client_loss\n",
    "            train_acc += client_acc\n",
    "\n",
    "        # Aggregate every I-th round, and on the last round so the returned model\n",
    "        # includes the final local updates.\n",
    "        if (epoch + 1) % I == 0 or epoch + 1 == num_epochs:\n",
    "            global_model.load_state_dict(average_weights(client_states))\n",
    "            client_states = [None] * args.num_users  # clients restart from the new global model\n",
    "\n",
    "            # Optional: aggregate embeddings (y_k) across clients\n",
    "            y_k = [(1 / args.num_users) * sum(y_k)] * args.num_users\n",
    "\n",
    "        # Average loss and accuracy over clients for this round\n",
    "        train_loss /= args.num_users\n",
    "        train_acc /= args.num_users\n",
    "\n",
    "        # Log metrics to WandB\n",
    "        wandb.log({\"train_loss\": train_loss, \"train_accuracy\": train_acc}, step=epoch)\n",
    "\n",
    "        # Log results to CSV\n",
    "        results.add(epoch=epoch+1, train_loss=train_loss, train_acc=train_acc)\n",
    "        results.save()\n",
    "\n",
    "        # Save a checkpoint of the global model every round\n",
    "        is_best = False  # You can adjust the logic for saving the best model\n",
    "        save_checkpoint_epoch({\n",
    "            'epoch': epoch + 1,\n",
    "            'state_dict': global_model.state_dict(),\n",
    "            'best_prec1': train_acc,\n",
    "        }, is_best, path=args.results_dir)\n",
    "\n",
    "        print(f\"Epoch {epoch+1} completed. Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%\")\n",
    "\n",
    "    return global_model, y_k\n",
    "\n",
    "def update_x_k(model, y_t, trainloader, eta_x, eta_y, lmbda, device='cuda'):\n",
    "    \"\"\"\n",
    "    Local update of model parameters x_k: one pass of plain SGD over `trainloader`.\n",
    "\n",
    "    Each mini-batch minimises log(mean(exp(loss_i / lmbda))) over the per-sample\n",
    "    cross-entropy losses. It is evaluated as logsumexp(loss / lmbda) - log(batch_size),\n",
    "    which is mathematically identical but cannot overflow when loss / lmbda is large\n",
    "    (a plain exp() turns into inf and the gradients into NaN).\n",
    "\n",
    "    Args:\n",
    "        model: Local client model; moved to `device` and updated in place.\n",
    "        y_t: Embedding y_k for the client. NOTE(review): not used by the update below;\n",
    "            confirm whether Algorithm 3 expects the gradient to be scaled by it.\n",
    "        trainloader: DataLoader yielding (images, labels) batches.\n",
    "        eta_x: SGD learning rate for model parameters x_k.\n",
    "        eta_y: Learning rate for embeddings y_k. NOTE(review): currently unused.\n",
    "        lmbda: Regularization parameter for the DRO loss; losses are divided by it.\n",
    "        device: 'cuda' or 'cpu'.\n",
    "\n",
    "    Returns:\n",
    "        The updated model parameters (model.state_dict()). The tensors are shared\n",
    "        with `model`, so callers should copy them before training the model further.\n",
    "    \"\"\"\n",
    "    criterion = nn.CrossEntropyLoss(reduction='none')\n",
    "    model.train()\n",
    "    model.to(device)\n",
    "\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=eta_x)\n",
    "\n",
    "    for images, labels in trainloader:\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        scaled_losses = criterion(model(images), labels) / lmbda\n",
    "        # log(mean(exp(z))) == logsumexp(z) - log(n), without overflowing exp(z).\n",
    "        f_obj = torch.logsumexp(scaled_losses, dim=0) - math.log(scaled_losses.numel())\n",
    "        f_obj.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "\n",
    "    return model.state_dict()\n",
    "\n",
    "def calculate_metrics(model, dataloader, device='cuda'):\n",
    "    \"\"\"\n",
    "    Calculates the loss and accuracy for the given model and dataloader.\n",
    "\n",
    "    The model runs in eval mode under torch.no_grad(); its previous train/eval\n",
    "    mode is restored afterwards, even if evaluation raises.\n",
    "\n",
    "    Args:\n",
    "        model: Classifier expected to return logits of shape (batch, num_classes).\n",
    "        dataloader: Iterable of (images, labels) batches.\n",
    "        device: 'cuda' or 'cpu'.\n",
    "\n",
    "    Returns:\n",
    "        (avg_loss, accuracy): sample-weighted mean cross-entropy and top-1 accuracy\n",
    "        in percent. Both are NaN when the dataloader yields no samples.\n",
    "    \"\"\"\n",
    "    criterion = nn.CrossEntropyLoss(reduction='sum')\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    total_loss, total_correct, total_samples = 0.0, 0, 0\n",
    "\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for images, labels in dataloader:\n",
    "                images, labels = images.to(device), labels.to(device)\n",
    "                outputs = model(images)\n",
    "                total_loss += criterion(outputs, labels).item()\n",
    "                _, predicted = outputs.max(1)\n",
    "                total_correct += predicted.eq(labels).sum().item()\n",
    "                total_samples += labels.size(0)\n",
    "    finally:\n",
    "        model.train(was_training)\n",
    "\n",
    "    if total_samples == 0:\n",
    "        # Avoid ZeroDivisionError for a client with no data.\n",
    "        return float('nan'), float('nan')\n",
    "    avg_loss = total_loss / total_samples\n",
    "    accuracy = 100. * total_correct / total_samples\n",
    "    return avg_loss, accuracy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2efb3e48-80bb-404a-99ab-a5d864c9f2cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b412c770-c0a2-4086-87a4-3620286b85d3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f9c5e8a-8f35-45f9-8335-949c11689475",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31d2d27c-e606-4dfe-8206-64e7c1283b50",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
