{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
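    "## Parameter Setup and Experiments\n",
    "\n",
    "For each value of `R`, generate a synthetic dataset with `simulation/generate_data.py`, then run FedAvg, LocalTrain and FedProx (once per `mu`) on it with `simulation/simulation_main.py`."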
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "import sys\n",
    "from datetime import datetime\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "os.chdir('../src')  # For server: the scripts below are addressed relative to src/\n",
    "\n",
    "# Interpreter for every child script. sys.executable is this kernel's Python, so the\n",
    "# simulations run in the same environment as the notebook; a bare `python` is resolved\n",
    "# through PATH and can silently pick a different environment.\n",
    "PYTHON = sys.executable\n",
    "\n",
    "# Parameters for all methods\n",
    "methods = ['FedAvg', 'LocalTrain', 'FedProx']\n",
    "mu_values = [0.02, 0.1, 0.5]\n",
    "L = 9.13  # Smoothness constant L used in the formulas below\n",
    "\n",
    "# FedProx step sizes for each mu: global lr = 1/mu + 1/L, local lr = 1/(L + mu)\n",
    "mu_global_lr_dict = {mu: (1/mu + 1/L) for mu in mu_values}\n",
    "mu_local_lr_dict = {mu: (1/(L + mu)) for mu in mu_values}\n",
    "lr_local_localtrain = 0.001  # lr_local * mu < 1\n",
    "lr_local_fedavg = 0.01\n",
    "\n",
    "random_seed = 1\n",
    "# FOR ATR: 10 evenly spaced values of R in [0.1, 3].\n",
    "# NOTE: the name is historical (the range is not [0, 2]); it is kept for backward\n",
    "# compatibility because other cells may refer to it.\n",
    "R_values_0_to_2 = np.linspace(0.1, 3, 10).tolist()\n",
    "# Combined\n",
    "R_combined = R_values_0_to_2\n",
    "\n",
    "num_clients = 10\n",
    "num_samples = 200\n",
    "input_dim = 10\n",
    "num_classes = 2\n",
    "local_epochs = 150\n",
    "data_dir = '../data/fedprox_syndata/test'\n",
    "output_data_dir = '../results/test'\n",
    "stopping_threshold = -1\n",
    "num_rounds = 50  # Number of rounds, LOCAL TRAIN convergence issue\n",
    "local_train_base_rounds = 500  #### TEST\n",
    "# LocalTrain used to get extra rounds (local_train_base_rounds + int(2**R * 100));\n",
    "# it now runs a fixed local_train_base_rounds for every R.\n",
    "\n",
    "# Parameters for data generation\n",
    "num_devices = num_clients\n",
    "x_dim = input_dim\n",
    "b_dim = num_classes\n",
    "\n",
    "\n",
    "def run_simulation(method, R, lr_local, rounds, lr_global=None, mu=None):\n",
    "    \"\"\"Run simulation/simulation_main.py once for `method` on the data generated for `R`.\n",
    "\n",
    "    `lr_global` and `mu` are only passed when given (FedProx). Arguments are handed to\n",
    "    the child process as a list, so no shell re-parses paths or values; `check=True`\n",
    "    raises subprocess.CalledProcessError if the script exits non-zero.\n",
    "    \"\"\"\n",
    "    cmd = [PYTHON, 'simulation/simulation_main.py',\n",
    "           '--method', method, '--num_clients', str(num_clients)]\n",
    "    if lr_global is not None:\n",
    "        cmd += ['--lr_global', str(lr_global)]\n",
    "    cmd += ['--lr_local', str(lr_local)]\n",
    "    if mu is not None:\n",
    "        cmd += ['--mu', str(mu)]\n",
    "    cmd += ['--input_dim', str(input_dim), '--num_classes', str(num_classes),\n",
    "            '--local_epochs', str(local_epochs),\n",
    "            '--data_dir', data_dir, '--output_data_dir', output_data_dir,\n",
    "            '--stopping_threshold', str(stopping_threshold), '--num_rounds', str(rounds),\n",
    "            '--R', str(R), '--record_error', 'stat']\n",
    "    subprocess.run(cmd, check=True)\n",
    "\n",
    "\n",
    "start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "print(f'Start time: {start_time}')\n",
    "\n",
    "# Loop through each choice of R\n",
    "for R in R_combined:\n",
    "    # Generate data for the current R with the specified dimensions and number of devices\n",
    "    print(f'Generating data for R={R} with {num_devices} devices...')\n",
    "    subprocess.run([PYTHON, 'simulation/generate_data.py',\n",
    "                    '--R', str(R), '--num_devices', str(num_devices), '--n_samples', str(num_samples),\n",
    "                    '--x_dim', str(x_dim), '--b_dim', str(b_dim), '--seed', str(random_seed),\n",
    "                    '--output_dir', data_dir], check=True)\n",
    "\n",
    "    for method in methods:\n",
    "        if method == 'FedProx':\n",
    "            # One FedProx run per mu, each with its own global/local step sizes\n",
    "            for mu in mu_values:\n",
    "                print(f'Running {method} with mu={mu} for R={R}...')\n",
    "                run_simulation(method, R, mu_local_lr_dict[mu], num_rounds,\n",
    "                               lr_global=mu_global_lr_dict[mu], mu=mu)\n",
    "        elif method == 'LocalTrain':\n",
    "            print(f'Running {method} for R={R}...')\n",
    "            run_simulation(method, R, lr_local_localtrain, local_train_base_rounds)\n",
    "        else:\n",
    "            # FedAvg (and any other method) uses the shared local lr and round budget\n",
    "            print(f'Running {method} for R={R}...')\n",
    "            run_simulation(method, R, lr_local_fedavg, num_rounds)\n",
    "\n",
    "end_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "print(f'End time: {end_time}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
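    "## Visualize the Results\n",
    "\n",
    "Plot the errors recorded by the runs above (`--record_error stat`) for every method, `mu` and `R` with `simulation/visualize_stat_error.py`."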
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FOR ATR\n",
    "import sys\n",
    "\n",
    "# Pass the same R values the experiment loop ran over (R_combined), together with the\n",
    "# methods and mu values, to the plotting script. Arguments go in as a list (no shell),\n",
    "# and sys.executable runs the script with this kernel's Python rather than PATH's `python`.\n",
    "command = [sys.executable, 'simulation/visualize_stat_error.py',\n",
    "           '--output_data_dir', output_data_dir,\n",
    "           '--R_values', *map(str, R_combined),\n",
    "           '--methods', *methods,\n",
    "           '--mu_values', *map(str, mu_values)]\n",
    "\n",
    "# Run the command using subprocess; check=True raises CalledProcessError on failure\n",
    "subprocess.run(command, check=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (fedprox)",
   "language": "python",
   "name": "fedprox"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
