{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import yaml\n",
    "\n",
    "def read_one_block_of_yaml_data(filename):\n",
    "    \"\"\"Parse `<filename>.yaml` with yaml.safe_load and return the resulting object.\"\"\"\n",
    "    yaml_path = f'{filename}.yaml'\n",
    "    with open(yaml_path, 'r') as stream:\n",
    "        return yaml.safe_load(stream)\n",
    "\n",
    "\n",
    "# Sweep configuration to plot, given without its .yaml extension.\n",
    "config_file = 'sweep/ACP'\n",
    "\n",
    "config = read_one_block_of_yaml_data(config_file)\n",
    "# Parameter grid of the sweep; the plotting cells read e.g. params['attack']['values'].\n",
    "params = config['parameters']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared figure style: LaTeX-rendered serif text, one font size for labels, ticks and titles.\n",
    "size=19\n",
    "size_legend=18\n",
    "mpl.rcParams.update({\n",
    "    \"pgf.texsystem\": \"pdflatex\",\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': 'Times',\n",
    "    'font.weight':'bold',\n",
    "    'text.usetex': True,\n",
    "    'pgf.rcfonts': False,\n",
    "    \"axes.grid\" : True,\n",
    "    'font.size': size,\n",
    "    'axes.labelsize':size,\n",
    "    'axes.titlesize':size,\n",
    "    'figure.titlesize':size,\n",
    "    'xtick.labelsize':size,\n",
    "    'ytick.labelsize':size,\n",
    "    'legend.fontsize':size_legend\n",
    "})\n",
    "# Alternative colorblind-safe palette (not used by the plots below).\n",
    "palette = sns.color_palette('colorblind')\n",
    "# Per-communication-rule styling, indexed by the rule's position in the sweep.\n",
    "colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "markers = ['o', 'v', 's', '*', 'd']\n",
    "dashstyle = ['-', '--', '-.', ':', ':']\n",
    "\n",
    "data_directory = \"results-data\"\n",
    "plot_directory = \"results-plot\"\n",
    "# The plotting cells save PDFs here; create it so Run All works on a fresh checkout.\n",
    "os.makedirs(plot_directory, exist_ok=True)\n",
    "\n",
    "# Display names used in figure titles and output file names.\n",
    "attack_name = {\n",
    "    \"FOE\": \"FOE\",\n",
    "    \"ALIE\": \"ALIE\",\n",
    "    \"Dissension\": \"Dissensus\",\n",
    "    \"SP Heterogeneity\": \"SpH\"\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def aggregate_experiment(name_experiment, seeds, data_directory=data_directory):\n",
    "    \"\"\"\n",
    "    Aggregate one experiment over several seeds.\n",
    "\n",
    "    Reads `<data_directory>/<name_experiment>_<seed>.csv` for every seed, stacks the\n",
    "    `loss_train` and `variance` columns, and computes per row the mean, std, min and max\n",
    "    across seeds. All loss statistics are divided by the initial (first-row) mean loss and\n",
    "    all variance statistics by the initial mean variance, so the std and the min/max band\n",
    "    are on the same scale as the corresponding mean curve.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    name_experiment : str\n",
    "        Run name, i.e. the CSV file name without the `_<seed>.csv` suffix.\n",
    "    seeds : iterable\n",
    "        Seeds whose CSV files are aggregated.\n",
    "    data_directory : str\n",
    "        Directory holding the CSV files (defaults to the notebook-level setting,\n",
    "        bound when this function is defined).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        Columns mean/std/min/max_var and mean/std/min/max_loss, indexed by the\n",
    "        `iteration` column of the first seed's CSV.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    FileNotFoundError\n",
    "        If a CSV file is missing.\n",
    "    KeyError\n",
    "        If a CSV lacks the `loss_train`, `variance` or `iteration` column.\n",
    "    ValueError\n",
    "        If `seeds` is empty, the CSV files have no rows, or they have different lengths.\n",
    "    \"\"\"\n",
    "    seeds = list(seeds)\n",
    "    if not seeds:\n",
    "        raise ValueError(f\"no seeds given for {name_experiment}\")\n",
    "\n",
    "    df_list = [\n",
    "        pd.read_csv(os.path.join(data_directory, f\"{name_experiment}_{seed}.csv\"), delimiter=',')\n",
    "        for seed in seeds\n",
    "    ]\n",
    "\n",
    "    # Shape (n_rows, 2, n_seeds): [:, 0, :] is loss_train, [:, 1, :] is variance.\n",
    "    # np.stack raises ValueError if the runs have different numbers of rows.\n",
    "    stacked = np.stack([df[['loss_train', 'variance']].values for df in df_list], axis=2)\n",
    "    if stacked.shape[0] == 0:\n",
    "        raise ValueError(f\"no rows in the CSV files of {name_experiment}\")\n",
    "    loss = stacked[:, 0, :]\n",
    "    var = stacked[:, 1, :]\n",
    "\n",
    "    # One common reference per family (initial mean over seeds) rather than each\n",
    "    # statistic's own first value: keeps min <= mean <= max after normalisation and\n",
    "    # avoids inf/nan std when every seed starts from the same value (std[0] == 0).\n",
    "    loss_ref = loss[0].mean()\n",
    "    var_ref = var[0].mean()\n",
    "\n",
    "    result_df = pd.DataFrame({\n",
    "        'mean_var': var.mean(axis=1) / var_ref,\n",
    "        'std_var': var.std(axis=1) / var_ref,\n",
    "        'min_var': var.min(axis=1) / var_ref,\n",
    "        'max_var': var.max(axis=1) / var_ref,\n",
    "        'mean_loss': loss.mean(axis=1) / loss_ref,\n",
    "        'std_loss': loss.std(axis=1) / loss_ref,\n",
    "        'min_loss': loss.min(axis=1) / loss_ref,\n",
    "        'max_loss': loss.max(axis=1) / loss_ref,\n",
    "    }, index=df_list[0][\"iteration\"])\n",
    "\n",
    "    return result_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each attack and number of Byzantine neighbours: normalised loss (with the min-max\n",
    "# band over seeds), variance and loss-minus-variance over iterations, one curve per rule.\n",
    "seeds = list(params['seed']['values'])\n",
    "\n",
    "for attack in params['attack']['values']:\n",
    "    for nb_byzantine_neighbors in params['nb_byzantine_neighbors']['values']:\n",
    "        fig, axs = plt.subplots(1, 3, figsize=(14, 5), sharex=False, sharey=False)\n",
    "\n",
    "        for rule_id, communication_rule in enumerate(params['communication_rule']['values']):\n",
    "            name = f\"{params['task']['value']}_{communication_rule}_{attack}_{params['topology']['value']}_{params['nb_honest']['value']}h_{nb_byzantine_neighbors}byz_{params['nb_iterations']['value']}iter\"\n",
    "\n",
    "            try:\n",
    "                df = aggregate_experiment(name, seeds)\n",
    "            except (OSError, LookupError, ValueError):\n",
    "                # Missing or malformed run (no CSV, missing column, bad length):\n",
    "                # skip this rule but keep the rest of the figure.\n",
    "                print(name, 'not found')\n",
    "                continue\n",
    "\n",
    "            style = dict(color=colors[rule_id], label=communication_rule, linestyle=dashstyle[rule_id])\n",
    "            axs[0].plot(df.index, df['mean_loss'], **style)\n",
    "            axs[0].fill_between(df.index, df['min_loss'], df['max_loss'], facecolor=colors[rule_id], alpha=0.2)\n",
    "            axs[1].plot(df.index, df['mean_var'], **style)\n",
    "            # NOTE(review): aggregate_experiment normalises each of these columns by its own\n",
    "            # initial value, so this difference is only a proxy for the bias -- confirm.\n",
    "            axs[2].plot(df.index, df['mean_loss'] - df['mean_var'], **style)\n",
    "\n",
    "        axs[0].set_title(r\"MSE$^t/$MSE$^0$\")\n",
    "        axs[1].set_title(\"Variance\")\n",
    "        axs[2].set_title(\"Bias\")\n",
    "        # All panels are plotted against the `iteration` index of the aggregated runs.\n",
    "        for ax in axs:\n",
    "            ax.set_xlabel(\"Iteration\")\n",
    "\n",
    "        axs[0].set_yscale(\"log\")\n",
    "        axs[1].set_yscale(\"log\")\n",
    "\n",
    "        handles, labels = axs[0].get_legend_handles_labels()\n",
    "        fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.52, 0.10), ncol=3, labelspacing=0.1, handletextpad=0.1, borderaxespad=0)\n",
    "        fig.tight_layout(w_pad=0.)\n",
    "\n",
    "        # The \"_minmax\" suffix stops the next cell, which saves its band-free version\n",
    "        # under the same base name, from overwriting this figure.\n",
    "        name_plot = f\"{params['task']['value']}_{attack_name.get(attack, attack)}_{params['topology']['value']}_{params['nb_honest']['value']}h_{nb_byzantine_neighbors}byz_minmax\"\n",
    "        fig.savefig(os.path.join(plot_directory, name_plot + '.pdf'), bbox_inches='tight')\n",
    "        print(name_plot)\n",
    "        plt.show()\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### For each attack and number of Byzantine neighbours: normalised loss, variance and\n",
    "#### loss-minus-variance over iterations, one curve per communication rule (no seed band).\n",
    "\n",
    "seeds = list(params['seed']['values'])\n",
    "\n",
    "for attack in params['attack']['values']:\n",
    "    for nb_byzantine_neighbors in params['nb_byzantine_neighbors']['values']:\n",
    "        fig, axs = plt.subplots(1, 3, figsize=(14, 5), sharex=False, sharey=False)\n",
    "\n",
    "        for rule_id, communication_rule in enumerate(params['communication_rule']['values']):\n",
    "            name = f\"{params['task']['value']}_{communication_rule}_{attack}_{params['topology']['value']}_{params['nb_honest']['value']}h_{nb_byzantine_neighbors}byz_{params['nb_iterations']['value']}iter\"\n",
    "\n",
    "            try:\n",
    "                df = aggregate_experiment(name, seeds)\n",
    "            except (OSError, LookupError, ValueError):\n",
    "                # Missing or malformed run (no CSV, missing column, bad length):\n",
    "                # skip this rule but keep the rest of the figure.\n",
    "                print(name, 'not found')\n",
    "                continue\n",
    "\n",
    "            style = dict(color=colors[rule_id], label=communication_rule, linestyle=dashstyle[rule_id])\n",
    "            axs[0].plot(df.index, df['mean_loss'], **style)\n",
    "            axs[1].plot(df.index, df['mean_var'], **style)\n",
    "            # NOTE(review): aggregate_experiment normalises each of these columns by its own\n",
    "            # initial value, so this difference is only a proxy for the bias -- confirm.\n",
    "            axs[2].plot(df.index, df['mean_loss'] - df['mean_var'], **style)\n",
    "\n",
    "        axs[0].set_title(\"Error\")\n",
    "        axs[1].set_title(\"Variance\")\n",
    "        axs[2].set_title(\"Bias\")\n",
    "        # All panels are plotted against the `iteration` index of the aggregated runs.\n",
    "        for ax in axs:\n",
    "            ax.set_xlabel(\"Iteration\")\n",
    "\n",
    "        axs[0].set_yscale(\"log\")\n",
    "        axs[1].set_yscale(\"log\")\n",
    "\n",
    "        handles, labels = axs[0].get_legend_handles_labels()\n",
    "        fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.52, 0.10), ncol=3, labelspacing=0.1, handletextpad=0.1, borderaxespad=0)\n",
    "        fig.tight_layout(w_pad=0.)\n",
    "\n",
    "        name_plot = f\"{params['task']['value']}_{attack_name.get(attack, attack)}_{params['topology']['value']}_{params['nb_honest']['value']}h_{nb_byzantine_neighbors}byz\"\n",
    "        fig.savefig(os.path.join(plot_directory, name_plot + '.pdf'), bbox_inches='tight')\n",
    "        print(name_plot)\n",
    "        plt.show()\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Final normalised MSE vs. number of Byzantine neighbours b:\n",
    "#### one panel per attack, one curve per communication rule.\n",
    "\n",
    "attacks = params['attack']['values']\n",
    "gars = params['communication_rule']['values']\n",
    "byzcounts = params['nb_byzantine_neighbors']['values']\n",
    "seeds = list(params['seed']['values'])\n",
    "final_iteration = params['nb_iterations']['value']\n",
    "# Plotted in place of a NaN final loss (diverged run); sits above the y-limit set below.\n",
    "DIVERGED_VALUE = 1000\n",
    "\n",
    "# squeeze=False keeps axs indexable even when the sweep has a single attack.\n",
    "fig, axs = plt.subplots(1, len(attacks), figsize=(14, 5), sharex=False, sharey=True, squeeze=False)\n",
    "axs = axs[0]\n",
    "\n",
    "for attack_id, attack in enumerate(attacks):\n",
    "    ax = axs[attack_id]\n",
    "    for rule_id, communication_rule in enumerate(gars):\n",
    "        X = []\n",
    "        Y = []\n",
    "        Ymin = []\n",
    "        Ymax = []\n",
    "        for byz in byzcounts:\n",
    "            name = f\"{params['task']['value']}_{communication_rule}_{attack}_{params['topology']['value']}_{params['nb_honest']['value']}h_{byz}byz_{params['nb_iterations']['value']}iter\"\n",
    "            try:\n",
    "                df = aggregate_experiment(name, seeds)\n",
    "            except (OSError, LookupError, ValueError):\n",
    "                # Missing or malformed run: leave this b out of the curve.\n",
    "                print(name, 'not found')\n",
    "                continue\n",
    "            X.append(byz)\n",
    "\n",
    "            y = df.loc[final_iteration, \"mean_loss\"]\n",
    "            Y.append(DIVERGED_VALUE if np.isnan(y) else y)\n",
    "            Ymin.append(df.loc[final_iteration, \"min_loss\"])\n",
    "            Ymax.append(df.loc[final_iteration, \"max_loss\"])\n",
    "\n",
    "        ax.plot(X, np.array(Y), color=colors[rule_id], marker=markers[rule_id], label=communication_rule, linestyle=dashstyle[rule_id])\n",
    "        ax.fill_between(X, Ymin, Ymax, facecolor=colors[rule_id], alpha=0.2)\n",
    "\n",
    "    ax.set_title(attack_name.get(attack, attack))\n",
    "    ax.set_xlabel(r\"b\")\n",
    "    ax.set_xticks(ticks=byzcounts)\n",
    "\n",
    "# sharey=True: the y-axis settings made on the first panel carry over to all panels.\n",
    "axs[0].set_ylabel(r\"$MSE^{T}/MSE^0$\", rotation=90, size=\"large\")\n",
    "axs[0].set_yscale(\"log\")\n",
    "axs[0].set_yticks(ticks=[1, 1e-2, 1e-4])\n",
    "axs[0].set_ylim(0.5*1e-5, 3)\n",
    "\n",
    "handles, labels = axs[0].get_legend_handles_labels()\n",
    "fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.52, 0.05), ncol=5, labelspacing=0.1, handletextpad=0.1, borderaxespad=0)\n",
    "fig.tight_layout(w_pad=0.2)\n",
    "\n",
    "# NOTE(review): the file name embeds only the last attack of the sweep (kept to match the\n",
    "# existing output name) although the figure covers every attack -- consider renaming.\n",
    "name_plot = f\"byz_test_{params['task']['value']}_{attacks[-1]}_{params['topology']['value']}_{params['nb_honest']['value']}h\"\n",
    "fig.savefig(os.path.join(plot_directory, name_plot + '.pdf'), bbox_inches='tight')\n",
    "print(name_plot)\n",
    "plt.show()\n",
    "plt.close(fig)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "robdec",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
