{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72f3132a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import csv\n",
    "import os\n",
    "from os import listdir\n",
    "from numpy import genfromtxt\n",
    "plt.rcParams.update({'font.size': 12})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb3cd2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final-round accuracy vs. number of sampled clients, without (f=0) and with\n",
    "# (f=30 of n=150 clients, per the folder name) Byzantine clients, averaged over runs.\n",
    "tab_f = [0, 30]\n",
    "tab_n_sampled = [i*5+1 for i in range(12)]\n",
    "nb_runs = 15\n",
    "T = 500  # presumably the number of rounds = values per Accuracies_*.csv (folder name has T_500) -- TODO confirm\n",
    "\n",
    "tab_acc = np.zeros((len(tab_f), len(tab_n_sampled), nb_runs, T))\n",
    "for i, f in enumerate(tab_f):\n",
    "    for j, n_sampled in enumerate(tab_n_sampled):\n",
    "        # NOTE(review): the folder name says nb_run_5 but nb_runs = 15 files are read from it -- verify.\n",
    "        folder_name = './save_fig2/T_500_n_150_f_'+str(f)+'_n_sampled_'+str(n_sampled)+'_server_lr_1_workers_lr_0.1_milestones_[300, 400, 460]_gamma_0.2_attack_SF_agg_trmean_batch_size_8_nb_local_steps_10_device_cuda_nb_run_5_seed_1'\n",
    "        for run in range(nb_runs):\n",
    "            tab_acc[i, j, run] = genfromtxt(folder_name+'/data/Accuracies_'+str(run)+'.csv', delimiter=',')\n",
    "\n",
    "# Accuracy at the last round: axes are (f, n_sampled, run).\n",
    "final_acc = tab_acc[:, :, :, -1]\n",
    "curves = final_acc.mean(axis=2)\n",
    "# 1.96 * std over runs describes the spread of individual runs (~95% band under a normal\n",
    "# assumption); it is not a CI of the mean, which would also divide by sqrt(nb_runs).\n",
    "errs = 1.96*final_acc.std(axis=2)\n",
    "\n",
    "styles = [\n",
    "    ('Without Byzantine clients', 's', '-', (0, 0.4470, 0.7410)),\n",
    "    ('With 20% of Byzantine clients', '^', '-.', (0.8500, 0.3250, 0.0980)),\n",
    "]\n",
    "x = np.array(tab_n_sampled)  # plot at the actual sample sizes 1, 6, ..., 56\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "for curve, err, (label, marker, linestyle, color) in zip(curves, errs, styles):\n",
    "    ax.plot(x, curve, label=label, marker=marker, linestyle=linestyle, color=color)\n",
    "    # Band uses the line colour; its lower edge is clipped at 0 since accuracy cannot be negative.\n",
    "    ax.fill_between(x, np.maximum(curve-err, 0), curve+err, alpha=0.15, color=color)\n",
    "\n",
    "ax.set_xlabel('Number of sampled clients')\n",
    "ax.set_ylabel('Accuracy')\n",
    "ax.legend(loc='lower right')\n",
    "ax.grid()\n",
    "ax.set_xlim(0, x.max())\n",
    "ax.set_ylim(0, 95)\n",
    "\n",
    "fig.savefig('fig2.pdf')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
