{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bdd9a42",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import matplotlib.ticker as ticker\n",
    "\n",
    "def plot_with_min_marker(fig, model_sizes, losses, label, linestyle='-', marker='o', color=None, smooth=False):\n",
    "    \"\"\"Draw `losses` against `model_sizes` on the Axes `fig` and mark the lowest point.\n",
    "\n",
    "    When `smooth` is true, both the curve and the marker use the EMA of `losses`.\n",
    "    \"\"\"\n",
    "    ys = exponential_moving_average(losses) if smooth else losses\n",
    "    sns.lineplot(x=model_sizes, y=ys, ax=fig, label=label, linestyle=linestyle, color=color)\n",
    "    best = np.argmin(ys)  # position of the minimum of the (possibly smoothed) curve\n",
    "    fig.plot(model_sizes[best], ys[best], marker=marker, markersize=10, color=color)\n",
    "    \n",
    "def plot_with_max_marker(fig, model_sizes, accs, label, linestyle='-', marker='o', color=None, smooth=True):\n",
    "    \"\"\"Draw `accs` against `model_sizes` on the Axes `fig` and mark the highest point.\n",
    "\n",
    "    Counterpart of `plot_with_min_marker`. Note the default here is `smooth=True`:\n",
    "    the curve and the marker use the EMA of `accs` unless `smooth=False` is passed.\n",
    "    \"\"\"\n",
    "    if smooth:\n",
    "        accs = exponential_moving_average(accs)\n",
    "    sns.lineplot(x=model_sizes, y=accs, ax=fig, label=label, linestyle=linestyle, color=color)\n",
    "    max_index = np.argmax(accs)\n",
    "    # Mark the best accuracy, mirroring how plot_with_min_marker marks the lowest loss.\n",
    "    fig.plot(model_sizes[max_index], accs[max_index], marker=marker, markersize=10, color=color)\n",
    "    \n",
    "def scatter_loss(fig, client_num, losses, label, color=None, smooth=False):\n",
    "    \"\"\"Scatter the minimum of `losses` as a single point at x=`client_num` on the Axes `fig`.\n",
    "\n",
    "    If `smooth` is true, the minimum is taken over the EMA of `losses`.\n",
    "    \"\"\"\n",
    "    if smooth:\n",
    "        losses = exponential_moving_average(losses)\n",
    "    min_index = np.argmin(losses)\n",
    "    # seaborn expects vector-like x/y (it builds a DataFrame from them), so wrap the point.\n",
    "    sns.scatterplot(x=np.atleast_1d(client_num), y=[losses[min_index]], ax=fig, label=label, color=color)\n",
    "    \n",
    "def scatter_acc(fig, client_num, accs, label, color=None, smooth=False):\n",
    "    \"\"\"Scatter the maximum of `accs` as a single point at x=`client_num` on the Axes `fig`.\n",
    "\n",
    "    If `smooth` is true, the maximum is taken over the EMA of `accs`.\n",
    "    \"\"\"\n",
    "    if smooth:\n",
    "        accs = exponential_moving_average(accs)\n",
    "    max_index = np.argmax(accs)\n",
    "    # seaborn expects vector-like x/y (it builds a DataFrame from them), so wrap the point.\n",
    "    sns.scatterplot(x=np.atleast_1d(client_num), y=[accs[max_index]], ax=fig, label=label, color=color)\n",
    "\n",
    "def exponential_moving_average(data, alpha=0.4):\n",
    "    \"\"\"Return the exponential moving average of `data` along its first axis.\n",
    "\n",
    "    ema[0] = data[0] and ema[i] = alpha * data[i] + (1 - alpha) * ema[i - 1].\n",
    "    Integer/bool input is promoted to float so the average is not truncated;\n",
    "    an empty input yields an empty array.\n",
    "    \"\"\"\n",
    "    values = np.asarray(data)\n",
    "    if not np.issubdtype(values.dtype, np.inexact):\n",
    "        # zeros_like on an int array would store truncated averages; use float instead\n",
    "        values = values.astype(float)\n",
    "    ema = np.empty_like(values)\n",
    "    if len(values) == 0:\n",
    "        return ema\n",
    "    ema[0] = values[0]\n",
    "    for i in range(1, len(values)):\n",
    "        ema[i] = alpha * values[i] + (1 - alpha) * ema[i - 1]\n",
    "    return ema\n",
    "\n",
    "def find_max_acc_model_size(accs, smooth=True, alpha=0.4, model_sizes=None):\n",
    "    \"\"\"Return the model size at which `accs` (optionally EMA-smoothed) peaks.\n",
    "\n",
    "    `model_sizes` gives the x-values aligned index-by-index with `accs`. When omitted,\n",
    "    falls back to the global `depths_list`, which is not defined in this notebook and\n",
    "    must already exist in the kernel namespace.\n",
    "    \"\"\"\n",
    "    if smooth:\n",
    "        accs = exponential_moving_average(accs, alpha)\n",
    "    if model_sizes is None:\n",
    "        model_sizes = depths_list  # legacy hidden-state fallback kept for backward compatibility\n",
    "    max_index = int(np.argmax(accs))\n",
    "    return model_sizes[max_index]\n",
    "\n",
    "def get_losses(loss_record):\n",
    "    \"\"\"Collect the last entry of each run's `losses` list, converted to float.\"\"\"\n",
    "    final_losses = []\n",
    "    for run in loss_record.values():\n",
    "        final_losses.append(float(run[\"losses\"][-1]))\n",
    "    return final_losses\n",
    "\n",
    "def get_accs(record):\n",
    "    \"\"\"Collect each entry's `test_acc` value as a float, in the record's iteration order.\"\"\"\n",
    "    return list(map(lambda entry: float(entry[\"test_acc\"]), record.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4fdeb40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the experimental results studying the impact of n (the number of clients)\n",
    "# for Vit-OneBlock and ResNet-18 on Mini-ImageNet and CIFAR-10.\n",
    "\n",
    "# Set Seaborn style with a more visually appealing color palette\n",
    "sns.set(style=\"whitegrid\", palette=\"bright\")  # You can try \"deep\", \"muted\", \"colorblind\", etc.\n",
    "\n",
    "# NOTE: the `= #` lines below are placeholders - paste the recorded values before running\n",
    "# (as written they are syntax errors). Each accuracy list must align with number_of_clients.\n",
    "number_of_clients = #\n",
    "\n",
    "vit_mini_test_accuracies = #\n",
    "cnn_mini_test_accuracies = #\n",
    "vit_cifar10_test_accuracies = #\n",
    "cnn_cifar10_test_accuracies = #\n",
    "\n",
    "# Smooth every curve with an EMA (default alpha); the raw lists are overwritten.\n",
    "vit_mini_test_accuracies = exponential_moving_average(vit_mini_test_accuracies)\n",
    "cnn_mini_test_accuracies = exponential_moving_average(cnn_mini_test_accuracies)\n",
    "vit_cifar10_test_accuracies = exponential_moving_average(vit_cifar10_test_accuracies)\n",
    "cnn_cifar10_test_accuracies = exponential_moving_average(cnn_cifar10_test_accuracies)\n",
    "\n",
    "# Set up one row of 2 subplots that share the y-axis\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)\n",
    "\n",
    "labelsize = 14\n",
    "title_fontsize = 18\n",
    "label_fontsize = 18\n",
    "\n",
    "# Left panel: Mini-ImageNet; right panel: CIFAR-10 (y-label only on the left, axes share y)\n",
    "sns.lineplot(ax=axes[0], x=number_of_clients, y=vit_mini_test_accuracies,  color=sns.color_palette()[0], legend='brief', label='Vit-OneBlock')\n",
    "sns.lineplot(ax=axes[0], x=number_of_clients, y=cnn_mini_test_accuracies, color=sns.color_palette()[1], legend='brief', label='ResNet-18')\n",
    "axes[0].set_title('Impact of n on Mini-Imagenet', fontsize=title_fontsize)\n",
    "axes[0].set_xlabel('Number of Clients', fontsize=label_fontsize)\n",
    "axes[0].set_ylabel('Test Accuracy', fontsize=label_fontsize)\n",
    "axes[0].tick_params(axis='both', which='major', labelsize=labelsize)\n",
    "\n",
    "sns.lineplot(ax=axes[1], x=number_of_clients, y=vit_cifar10_test_accuracies, color=sns.color_palette()[0], legend='brief', label='Vit-OneBlock')\n",
    "sns.lineplot(ax=axes[1], x=number_of_clients, y=cnn_cifar10_test_accuracies, color=sns.color_palette()[1], legend='brief', label='ResNet-18')\n",
    "axes[1].set_title('Impact of n on CIFAR-10', fontsize=title_fontsize)\n",
    "axes[1].set_xlabel('Number of Clients', fontsize=label_fontsize)\n",
    "axes[1].tick_params(axis='both', which='major', labelsize=labelsize)\n",
    "axes[1].legend()\n",
    "\n",
    "# Adjust layout\n",
    "plt.tight_layout()\n",
    "\n",
    "# Uncomment to export the figure for the paper\n",
    "# plt.savefig(\"impact_of_n.pdf\", format='pdf')\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62094cbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the experimental results studying the impact of d (the model size):\n",
    "# centralized vs. federated (20 clients) test accuracy, with the gap between them shaded.\n",
    "\n",
    "# Set Seaborn style with a more visually appealing color palette\n",
    "sns.set(style=\"whitegrid\", palette=\"bright\")  # You can try \"deep\", \"muted\", \"colorblind\", etc.\n",
    "\n",
    "# NOTE: the `= #` lines below are placeholders - paste the recorded values before running\n",
    "# (as written they are syntax errors). Each *_record is read by get_accs, i.e. a mapping whose\n",
    "# values carry a \"test_acc\" entry; presumably ordered like the matching model-size list - confirm.\n",
    "vit_model_size_list = #\n",
    "cnn_model_size_list = #\n",
    "\n",
    "vit_mini_cen_record = #\n",
    "vit_mini_decen20c_record = #\n",
    "cnn_cifar10_cen_record = #\n",
    "cnn_cifar10_decen20c_record = #\n",
    "vit_mini_cen_accs = get_accs(vit_mini_cen_record)\n",
    "vit_mini_decen20c_accs = get_accs(vit_mini_decen20c_record)\n",
    "cnn_cifar10_cen_accs = get_accs(cnn_cifar10_cen_record)\n",
    "cnn_cifar10_decen20c_accs = get_accs(cnn_cifar10_decen20c_record)\n",
    "# NOTE(review): gap_accs / gap_accs_2 are computed and smoothed but never plotted in this cell.\n",
    "gap_accs = [vit_mini_cen_accs[i] - vit_mini_decen20c_accs[i] for i in range(len(vit_mini_cen_accs))]\n",
    "gap_accs_2 = [cnn_cifar10_cen_accs[i] - cnn_cifar10_decen20c_accs[i] for i in range(len(cnn_cifar10_cen_accs))]\n",
    "\n",
    "\n",
    "# Smooth every curve with an EMA (default alpha). The results are numpy arrays, which is what\n",
    "# makes the element-wise `where=(a > b)` masks in fill_between below work.\n",
    "vit_mini_cen_accs = exponential_moving_average(vit_mini_cen_accs)\n",
    "vit_mini_decen20c_accs = exponential_moving_average(vit_mini_decen20c_accs)\n",
    "cnn_cifar10_cen_accs = exponential_moving_average(cnn_cifar10_cen_accs)\n",
    "cnn_cifar10_decen20c_accs = exponential_moving_average(cnn_cifar10_decen20c_accs)\n",
    "gap_accs = exponential_moving_average(gap_accs)\n",
    "gap_accs_2 = exponential_moving_average(gap_accs_2)\n",
    "\n",
    "labelsize = 14\n",
    "title_fontsize = 18\n",
    "label_fontsize = 18\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "sns.lineplot(ax=axes[0], x=vit_model_size_list, y=vit_mini_cen_accs, label='Centralized', color=sns.color_palette()[0])\n",
    "sns.lineplot(ax=axes[0], x=vit_model_size_list, y=vit_mini_decen20c_accs, label='Federated', color=sns.color_palette()[1])\n",
    "axes[0].set_title('Vit-OneBlock on Mini-ImageNet', fontsize=title_fontsize)\n",
    "axes[0].set_xlabel('Model Size', fontsize=label_fontsize)\n",
    "axes[0].set_ylabel('Test Accuracy', fontsize=label_fontsize)\n",
    "axes[0].tick_params(axis='both', which='major', labelsize=labelsize)\n",
    "# NOTE(review): legend() runs before fill_between, so 'Gap Area 1' presumably does not\n",
    "# appear in the legend - confirm whether that is intended.\n",
    "axes[0].legend(loc='upper left')\n",
    "\n",
    "# Shade the region where centralized accuracy exceeds federated accuracy\n",
    "axes[0].fill_between(\n",
    "    vit_model_size_list, \n",
    "    vit_mini_cen_accs, \n",
    "    vit_mini_decen20c_accs, \n",
    "    where=(vit_mini_cen_accs > vit_mini_decen20c_accs), \n",
    "    interpolate=True, \n",
    "    color='lightblue', \n",
    "    alpha=0.3,  # Transparency of the shaded area\n",
    "    label='Gap Area 1'\n",
    ")\n",
    "\n",
    "sns.lineplot(ax=axes[1], x=cnn_model_size_list, y=cnn_cifar10_cen_accs, label='Centralized', color=sns.color_palette()[0])\n",
    "sns.lineplot(ax=axes[1], x=cnn_model_size_list, y=cnn_cifar10_decen20c_accs, label='Federated', color=sns.color_palette()[1])\n",
    "axes[1].set_title('ResNet-18 on CIFAR-10', fontsize=title_fontsize)\n",
    "axes[1].set_xlabel('Model Size', fontsize=label_fontsize)\n",
    "axes[1].set_ylabel('Test Accuracy', fontsize=label_fontsize)\n",
    "axes[1].tick_params(axis='both', which='major', labelsize=labelsize)\n",
    "axes[1].legend(loc='upper left')\n",
    "\n",
    "# Shade the region where centralized accuracy exceeds federated accuracy\n",
    "axes[1].fill_between(\n",
    "    cnn_model_size_list, \n",
    "    cnn_cifar10_cen_accs, \n",
    "    cnn_cifar10_decen20c_accs, \n",
    "    where=(cnn_cifar10_cen_accs > cnn_cifar10_decen20c_accs), \n",
    "    interpolate=True, \n",
    "    color='lightblue', \n",
    "    alpha=0.3,  # Transparency of the shaded area\n",
    "    label='Gap Area 2'\n",
    ")\n",
    "\n",
    "\n",
    "# Adjust layout\n",
    "plt.tight_layout()\n",
    "\n",
    "# Uncomment to export the figure for the paper\n",
    "# plt.savefig(\"impact_of_d.pdf\", format='pdf')\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c133369d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot how the centralized/federated gap closes by increasing either n (adding new clients)\n",
    "# or m (adding data to the existing clients).\n",
    "\n",
    "# Set Seaborn style with a more visually appealing color palette\n",
    "sns.set(style=\"whitegrid\", palette=\"bright\")  # You can try \"deep\", \"muted\", \"colorblind\", etc.\n",
    "\n",
    "# NOTE: the `= #` lines below are placeholders - paste the recorded values before running\n",
    "# (as written they are syntax errors). The *_cen_* values are single accuracies drawn as\n",
    "# horizontal reference lines; the *_decen_* lists must align with their x-axis list.\n",
    "number_of_clients = #\n",
    "number_of_data = #\n",
    "\n",
    "# bridge by n\n",
    "vit_mini_cen_test_accuracy = #\n",
    "vit_mini_decen_test_accuracies = #\n",
    "# bridge by m\n",
    "vit_mini_cen_test_accuracy_2 = #\n",
    "vit_mini_decen_test_accuracies_2 = # \n",
    "\n",
    "# Smooth only the federated curves; the centralized baselines are constants.\n",
    "vit_mini_decen_test_accuracies = exponential_moving_average(vit_mini_decen_test_accuracies)\n",
    "vit_mini_decen_test_accuracies_2 = exponential_moving_average(vit_mini_decen_test_accuracies_2)\n",
    "\n",
    "# Set up one row of 2 subplots that share the y-axis\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)\n",
    "\n",
    "labelsize = 14\n",
    "title_fontsize = 18\n",
    "label_fontsize = 18\n",
    "\n",
    "\n",
    "# Left panel: vary the number of clients against the centralized baseline\n",
    "# NOTE(review): the d/T cell labels its baseline '48000 Data' - confirm '4800' is intended here.\n",
    "axes[0].axhline(y=vit_mini_cen_test_accuracy, color=sns.color_palette()[0], linestyle='--', label='Centralized / 4800 Data')\n",
    "sns.lineplot(ax=axes[0], x=number_of_clients, y=vit_mini_decen_test_accuracies,  color=sns.color_palette()[1], legend='brief', label='Federated')\n",
    "axes[0].set_title('Bridge Gap by Incorporating New Clients', fontsize=title_fontsize)\n",
    "axes[0].set_xlabel('Number of Clients', fontsize=label_fontsize)\n",
    "axes[0].set_ylabel('Test Accuracy', fontsize=label_fontsize)\n",
    "axes[0].tick_params(axis='both', which='major', labelsize=labelsize)\n",
    "axes[0].legend(loc='upper left')\n",
    "\n",
    "# Ensure x-axis shows only integer values\n",
    "axes[0].xaxis.set_major_locator(ticker.MaxNLocator(integer=True))\n",
    "\n",
    "# Right panel: vary the average per-client data size against the centralized baseline\n",
    "axes[1].axhline(y=vit_mini_cen_test_accuracy_2, color=sns.color_palette()[0], linestyle='--', label='Centralized / 4800 Data')\n",
    "sns.lineplot(ax=axes[1], x=number_of_data, y=vit_mini_decen_test_accuracies_2,  color=sns.color_palette()[1], legend='brief', label='Federated')\n",
    "axes[1].set_title('Bridge Gap by Adding Data to Existing Clients', fontsize=title_fontsize)\n",
    "axes[1].set_xlabel('Average Data Size across Clients', fontsize=label_fontsize)\n",
    "axes[1].tick_params(axis='both', which='major', labelsize=labelsize)\n",
    "axes[1].legend(loc='upper left')\n",
    "\n",
    "# Adjust layout\n",
    "plt.tight_layout()\n",
    "\n",
    "# Uncomment to export the figure for the paper\n",
    "# plt.savefig(\"bridge_gap_by_data.pdf\", format='pdf')\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09eb12e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot how the centralized/federated gap closes by increasing either d (model size)\n",
    "# or T (number of communication rounds).\n",
    "\n",
    "\n",
    "# Set Seaborn style with a more visually appealing color palette\n",
    "sns.set(style=\"whitegrid\", palette=\"bright\")  # You can try \"deep\", \"muted\", \"colorblind\", etc.\n",
    "\n",
    "# NOTE: the `= #` lines below are placeholders - paste the recorded values before running\n",
    "# (as written they are syntax errors). The *_cen_* values are single accuracies drawn as\n",
    "# horizontal reference lines; the *_decen_* lists must align with their x-axis list.\n",
    "model_size_list =  #\n",
    "rounds_list = #\n",
    "\n",
    "# bridge by d\n",
    "vit_mini_cen_test_accuracy = #\n",
    "vit_mini_decen_test_accuracies = #\n",
    "# bridge by T\n",
    "vit_mini_cen_test_accuracy_2 = #\n",
    "vit_mini_decen_test_accuracies_2 = #\n",
    "\n",
    "\n",
    "# Smooth only the federated curves; the centralized baselines are constants.\n",
    "vit_mini_decen_test_accuracies = exponential_moving_average(vit_mini_decen_test_accuracies)\n",
    "vit_mini_decen_test_accuracies_2 = exponential_moving_average(vit_mini_decen_test_accuracies_2)\n",
    "\n",
    "# Set up one row of 2 subplots (independent y-axes)\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
    "\n",
    "labelsize = 14\n",
    "title_fontsize = 18\n",
    "label_fontsize = 18\n",
    "\n",
    "\n",
    "# Left panel: vary model size against the centralized baseline\n",
    "axes[0].axhline(y=vit_mini_cen_test_accuracy, color=sns.color_palette()[0], linestyle='--', label='Centralized / 48000 Data')\n",
    "sns.lineplot(ax=axes[0], x=model_size_list, y=vit_mini_decen_test_accuracies,  color=sns.color_palette()[1], legend='brief', label='Federated')\n",
    "axes[0].set_title('Bridge Gap by Scaling Model Size', fontsize=title_fontsize)\n",
    "axes[0].set_xlabel('Model Size', fontsize=label_fontsize)\n",
    "axes[0].set_ylabel('Test Accuracy', fontsize=label_fontsize)\n",
    "axes[0].tick_params(axis='both', which='major', labelsize=labelsize)\n",
    "axes[0].legend(loc='center right')\n",
    "\n",
    "# Ensure x-axis shows only integer values\n",
    "axes[0].xaxis.set_major_locator(ticker.MaxNLocator(integer=True))\n",
    "\n",
    "# Right panel: vary the number of communication rounds against the centralized baseline\n",
    "axes[1].axhline(y=vit_mini_cen_test_accuracy_2, color=sns.color_palette()[0], linestyle='--', label='Centralized / 48000 Data')\n",
    "sns.lineplot(ax=axes[1], x=rounds_list, y=vit_mini_decen_test_accuracies_2,  color=sns.color_palette()[1], legend='brief', label='Federated')\n",
    "axes[1].set_title('Bridge Gap by Increasing Communication Rounds', fontsize=title_fontsize)\n",
    "axes[1].set_xlabel('Number of Rounds', fontsize=label_fontsize)\n",
    "axes[1].tick_params(axis='both', which='major', labelsize=labelsize)\n",
    "axes[1].legend(loc='center right')\n",
    "\n",
    "# Adjust layout\n",
    "plt.tight_layout()\n",
    "\n",
    "# Two alternative export targets; uncomment the one matching the data pasted above\n",
    "# plt.savefig(\"bridge_gap_by_d_T_1.pdf\", format='pdf')\n",
    "\n",
    "# plt.savefig(\"bridge_gap_by_d_T_2.pdf\", format='pdf')\n",
    "\n",
    "# Show the plot\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
