{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from configuration import config_jup\n",
    "from utils.data_loader import get_loader_all_clients\n",
    "from utils.train_utils import get_free_gpu_idx, get_logger, initialize_clients, FedAvg, weightedFedAvg, test_global_model, save_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the experiment configuration and create the logger that later cells fill in.\n",
    "args = config_jup.base_parser()\n",
    "logger = get_logger(args)\n",
    "\n",
    "# Target the first CUDA device when one is visible, otherwise fall back to the CPU.\n",
    "args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "if args.device == 'cuda:0':\n",
    "    # Only flipped on when CUDA is found; on CPU the parser's own value is left untouched.\n",
    "    args.cuda = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `args` and `logger` are created in the setup cell above. Calling the parser again here\n",
    "# would rebind `args` and drop the `cuda` / `device` values chosen there, so this cell only\n",
    "# displays the configuration that the rest of the notebook will actually use.\n",
    "print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main experiment: for every run, build the clients, train them task by task with periodic\n",
    "# federated averaging, and finally evaluate the aggregated global model.\n",
    "for run in range(args.n_runs):\n",
    "    loader_clients, cls_assignment_list, global_test_loader = get_loader_all_clients(args, run)\n",
    "    clients = initialize_clients(args, loader_clients, cls_assignment_list, run)\n",
    "    # Reset for every run. Without this, a run in which no communication round fires would\n",
    "    # raise NameError at the final global test (first run) or silently evaluate the model\n",
    "    # left over from the previous run.\n",
    "    global_model = None\n",
    "\n",
    "    # Each pass asks every client that is still training for its next batch, then attempts a\n",
    "    # communication round; the loop ends once every client has completed its last task.\n",
    "    while not all(client.train_completed for client in clients):\n",
    "        for client in clients:\n",
    "            if client.train_completed:\n",
    "                continue\n",
    "\n",
    "            samples, labels = client.get_next_batch()\n",
    "            if samples is not None:\n",
    "                if args.with_memory:\n",
    "                    # The first task uses train_with_update, later tasks use train_with_memory.\n",
    "                    if client.task_id == 0:\n",
    "                        client.train_with_update(samples, labels)\n",
    "                    else:\n",
    "                        client.train_with_memory(samples, labels)\n",
    "                else:\n",
    "                    client.train(samples, labels)\n",
    "            else:\n",
    "                # No batch returned: the current task is exhausted, so evaluate the client.\n",
    "                print(f'Run {run} - Client {client.client_id} - Task {client.task_id} completed - {client.get_current_task()}')\n",
    "                # compute loss train\n",
    "                logger = client.compute_loss(logger, run)\n",
    "                print(f'Run {run} - Client {client.client_id} - Test time - Task {client.task_id}')\n",
    "                logger = client.test(logger, run)\n",
    "                logger = client.validation(logger, run)\n",
    "                logger = client.forgetting(logger, run)\n",
    "\n",
    "                if client.task_id + 1 >= args.n_tasks:\n",
    "                    # That was the last task: mark the client as done and log balanced accuracy.\n",
    "                    client.train_completed = True\n",
    "                    print(f'Run {run} - Client {client.client_id} - Train completed')\n",
    "                    logger = client.balanced_accuracy(logger, run)\n",
    "                else:\n",
    "                    client.task_id += 1\n",
    "\n",
    "        # COMMUNICATION ROUND PART\n",
    "        # Eligible clients: still training, num_batches >= args.burnin, and num_batches a\n",
    "        # multiple of args.jump.\n",
    "        selected_clients = [\n",
    "            client.client_id\n",
    "            for client in clients\n",
    "            if client.num_batches >= args.burnin\n",
    "            and client.num_batches % args.jump == 0\n",
    "            and not client.train_completed\n",
    "        ]\n",
    "        # Averaging only happens when at least two clients are eligible in this pass.\n",
    "        if len(selected_clients) > 1:\n",
    "            if args.fl_update.startswith('w_'):\n",
    "                global_model = weightedFedAvg(args, selected_clients, clients)\n",
    "            else:\n",
    "                global_model = FedAvg(args, selected_clients, clients)\n",
    "\n",
    "            global_parameters = global_model.state_dict()\n",
    "            # local models update with averaged global parameters\n",
    "            # NOTE(review): indexing `clients` by client_id assumes client_id equals the list\n",
    "            # position produced by initialize_clients - confirm.\n",
    "            for client_id in selected_clients:\n",
    "                clients[client_id].save_last_local_model()\n",
    "                clients[client_id].update_parameters(global_parameters)\n",
    "                clients[client_id].save_last_global_model(global_model)\n",
    "\n",
    "    # global model accuracy when all clients finish their training on all tasks (FedCIL ICLR2023)\n",
    "    if global_model is None:\n",
    "        # No aggregation happened in this run, so there is no global model to evaluate.\n",
    "        print(f'Run {run} - No communication round took place - global model test skipped')\n",
    "    else:\n",
    "        logger = test_global_model(args, global_test_loader, global_model, logger, run)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_avg_acc_for(args, logger):\n",
    "    \"\"\"Summarise final test accuracy and forgetting over all clients.\n",
    "\n",
    "    For each client in ``range(args.n_clients)`` the accuracy log\n",
    "    ``logger[\"test\"][\"acc\"][client_id]`` is averaged over its first axis\n",
    "    (presumably the runs - confirm), the row ``args.n_tasks - 1`` of the result is\n",
    "    taken and averaged again. Forgetting is the plain mean of\n",
    "    ``logger[\"test\"][\"forget\"][client_id]``.\n",
    "\n",
    "    Args:\n",
    "        args: configuration object exposing ``n_clients`` and ``n_tasks``.\n",
    "        logger: nested mapping holding the ``\"test\"`` -> ``\"acc\"`` / ``\"forget\"`` logs.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(mean_acc, std_acc, mean_forget, std_forget)`` computed across clients.\n",
    "    \"\"\"\n",
    "    acc_per_client, forget_per_client = [], []\n",
    "    for cid in range(args.n_clients):\n",
    "        # Collapse the leading axis first, then read the row of the last task.\n",
    "        acc_table = np.mean(logger[\"test\"][\"acc\"][cid], 0)\n",
    "        last_task_row = acc_table[args.n_tasks - 1, :]\n",
    "        forget_value = np.mean(logger[\"test\"][\"forget\"][cid])\n",
    "        acc_per_client.append(np.mean(last_task_row, 0))\n",
    "        forget_per_client.append(forget_value)\n",
    "\n",
    "    acc_summary = (np.mean(acc_per_client), np.std(acc_per_client))\n",
    "    forget_summary = (np.mean(forget_per_client), np.std(forget_per_client))\n",
    "    return acc_summary + forget_summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-client report, using the `clients` list left by the training loop and the metrics in `logger`.\n",
    "for client_id in range(args.n_clients):\n",
    "    print(f'Client {client_id}: {clients[client_id].task_list}')\n",
    "\n",
    "    # Accuracy log averaged over its first axis (presumably the runs - confirm).\n",
    "    acc_table = np.mean(logger['test']['acc'][client_id], 0)\n",
    "    print(acc_table)\n",
    "\n",
    "    # Final accuracy: mean of the row recorded for the last task.\n",
    "    final_acc = np.mean(acc_table[args.n_tasks-1, :], 0)\n",
    "    print(f'Final client accuracy: {final_acc}')\n",
    "\n",
    "    final_forgetting = np.mean(logger['test']['forget'][client_id])\n",
    "    print(f'Final client forgetting: {final_forgetting}')\n",
    "\n",
    "    final_bal_acc = np.mean(logger['test']['bal_acc'][client_id])\n",
    "    print(f'Final client balanced accuracy: {final_bal_acc}')\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-client summary: mean and standard deviation of final accuracy and forgetting.\n",
    "mean_acc, std_acc, mean_for, std_for = compute_avg_acc_for(args, logger)\n",
    "for metric_name, metric_mean, metric_std in (\n",
    "    ('accuracy', mean_acc, std_acc),\n",
    "    ('forgetting', mean_for, std_for),\n",
    "):\n",
    "    print(f'Final average {metric_name}: {metric_mean:0.4f} (+-) {metric_std:0.4f}')\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save training results: hand the configuration and the accumulated logger to the\n",
    "# save_results helper imported from utils.train_utils (destination and format are\n",
    "# decided inside that helper - presumably driven by `args`; confirm before relying on it).\n",
    "save_results(args, logger)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
