{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameter Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "import subprocess\n",
    "import numpy as np\n",
    "import os; os.chdir('../src') # For server\n",
    "from datetime import datetime\n",
    "\n",
    "# Parameters for all methods\n",
    "methods = ['FedProx']\n",
    "mu_values = [0.02, 0.25, 0.5]\n",
    "L = 9.13  # Smoothness constant L used in the formula\n",
    "\n",
    "# Create mu_global_lr_dict\n",
    "mu_global_lr_dict = {mu: (1/mu + 1/L) for mu in mu_values}\n",
    "# Create mu_local_lr_dict\n",
    "mu_local_lr_dict = {mu: (1/(L + mu)) for mu in mu_values}\n",
    "\n",
    "random_seed = 1\n",
    "R_combined = [2.0]\n",
    "\n",
    "num_clients = 1000\n",
    "sampling_rate = 0.1\n",
    "num_samples = 200\n",
    "input_dim = 10\n",
    "num_classes = 2\n",
    "local_epochs = 150 ### TEST: try more less local update for plotting\n",
    "data_dir = '../data/fedprox_syndata/test'\n",
    "output_data_dir = '../results/test'\n",
    "stopping_threshold = -1\n",
    "num_rounds = 100 \n",
    "\n",
    "# Parameters for data generation\n",
    "num_devices = num_clients\n",
    "x_dim = input_dim\n",
    "b_dim = num_classes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Experiment And Collect Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "print(f\"Start time: {start_time}\")\n",
    "\n",
    "# Loop through each choice of R\n",
    "for R in R_combined:\n",
    "    # Generate data for the current R with specified dimensions and number of devices\n",
    "    print(f\"Generating data for R={R} with {num_devices} devices...\")\n",
    "    subprocess.run(f\"python simulation/generate_data.py --R {R} --num_devices {num_devices} --n_samples {num_samples} \"\n",
    "                   f\"--x_dim {x_dim} --b_dim {b_dim} --seed {random_seed} \"\n",
    "                   f\"--output_dir {data_dir}\", shell=True, check=True)\n",
    "\n",
    "    # Loop through each method\n",
    "    for method in methods:\n",
    "        if method == 'FedProx':\n",
    "            # For FedProx, iterate over different mu values\n",
    "            for mu in mu_global_lr_dict.keys():\n",
    "                global_lr_fedprox = mu_global_lr_dict[mu]\n",
    "                local_lr_fedprox = mu_local_lr_dict[mu]\n",
    "                # Single run for FedProx with specified mu\n",
    "                print(f\"Running {method} with mu={mu} for R={R}...\")\n",
    "                subprocess.run(f\"python simulation/simulation_main.py --method {method} --num_clients {num_clients} \"\n",
    "                               f\"--lr_global {global_lr_fedprox} --lr_local {local_lr_fedprox} --mu {mu} \"\n",
    "                               f\"--input_dim {input_dim} --num_classes {num_classes} --local_epochs {local_epochs} \"\n",
    "                               f\"--data_dir {data_dir} --output_data_dir {output_data_dir} \"\n",
    "                               f\"--stopping_threshold {stopping_threshold} --num_rounds {num_rounds} \"\n",
    "                               f\"--R {R} --record_error stat --sampling_rate {sampling_rate}\", shell=True, check=True)\n",
    "            \n",
    "\n",
    "end_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "print(f\"End time: {end_time}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FOR Total Error Visualizaiton\n",
    "# Convert the lists into space-separated strings\n",
    "R_value = R_combined[0]\n",
    "mu_values_str = ' '.join(map(str, mu_values))\n",
    "methods_str = ' '.join(methods)\n",
    "\n",
    "# Command for subprocess\n",
    "command = f\"python simulation/visualize_total_error.py --output_data_dir {output_data_dir} --R_value {R_value} --methods {methods_str} --mu_values {mu_values_str} --local_epochs {local_epochs}\"\n",
    "\n",
    "# Run the command using subprocess\n",
    "subprocess.run(command, shell=True, check=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (fedprox)",
   "language": "python",
   "name": "fedprox"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
