{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameter Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "from datetime import datetime\n",
    "\n",
    "# Run from the project's src directory; the relative paths used below resolve against it\n",
    "os.chdir('../src')\n",
    "\n",
    "# Experiment configuration\n",
    "model = 'mnist'  # dataset/model tag: 'mnist' or 'mnist_cnn'\n",
    "local_update_number = 5  # local epochs E; selects the results/<model>/E<E> subdirectory\n",
    "fedprox_mu_values = [0.5, 1.5, 2.5]  # FedProx mu values (plotted as small/medium/large lambda)\n",
    "\n",
    "# Automatically create the methods to evaluate based on the given mu values for FedProx\n",
    "def create_methods_to_evaluate(fedprox_mu_values, fedprox_exclusive=False):\n",
    "    \"\"\"Create a list of methods to evaluate, including FedAvg, LocalTrain, and FedProx with specific mu values.\"\"\"\n",
    "    methods_to_evaluate = [\"FedAvg\", \"LocalTrain\"] if not fedprox_exclusive else []\n",
    "    for mu in fedprox_mu_values:\n",
    "        methods_to_evaluate.append(f\"FedProx_mu{mu}\")\n",
    "    return methods_to_evaluate\n",
    "\n",
    "# Method label lists: the full comparison set and the FedProx-only subset\n",
    "methods_fedprox = create_methods_to_evaluate(fedprox_mu_values, fedprox_exclusive=True)\n",
    "methods_to_evaluate = create_methods_to_evaluate(fedprox_mu_values)\n",
    "\n",
    "# Directory the result logs are read from (one subdirectory per model and local-epoch setting)\n",
    "result_directory = f'../results/{model}/E{local_update_number}'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collect Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import pandas as pd\n",
    "\n",
    "# Define the metrics we're interested in\n",
    "metric_names = [\n",
    "    \"Global Train Loss\",\n",
    "    \"Global Test Loss\",\n",
    "    \"Global Test Accuracy\",\n",
    "    \"Global Evaluation Loss\",\n",
    "    \"Global Evaluation Accuracy\",\n",
    "    \"Average Evaluation Loss\",\n",
    "    \"Average Evaluation Accuracy\",\n",
    "    \"Average Naive Baseline Loss\",\n",
    "    \"Average Naive Baseline Accuracy\",\n",
    "    \"Average Evaluation Improvement Over Baseline\",\n",
    "    \"Average Training Loss\",\n",
    "    \"Average Training Accuracy\",\n",
    "    \"Average Test Loss\",\n",
    "    \"Average Test Accuracy\"\n",
    "]\n",
    "\n",
    "\n",
    "# Create a dictionary to store the results, with the structure:\n",
    "# {client_classes: {metric: {method: value}}}\n",
    "results = {}\n",
    "\n",
    "# Function to extract relevant data from the file content\n",
    "def extract_results_from_file(file_path):\n",
    "    with open(file_path, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "    \n",
    "    # Define the regex patterns for the rows of interest\n",
    "    patterns = {\n",
    "        \"Global Train Loss\": r\"Global Train Loss across all clients: (.+)\",\n",
    "        \"Global Test Loss\": r\"Global Test Loss: (.+)\",\n",
    "        \"Global Test Accuracy\": r\"Global Test Accuracy: (.+)\",\n",
    "        \"Global Evaluation Loss\": r\"Global Evaluation Loss: (.+)\",\n",
    "        \"Global Evaluation Accuracy\": r\"Global Evaluation Accuracy: (.+)\",\n",
    "        \"Average Evaluation Loss\": r\"Average evaluation loss across clients: (.+)\",\n",
    "        \"Average Evaluation Accuracy\": r\"Average evaluation accuracy across clients: (.+)\",\n",
    "        \"Average Naive Baseline Loss\": r\"Average naive baseline loss across clients: (.+)\",\n",
    "        \"Average Naive Baseline Accuracy\": r\"Average naive baseline accuracy across clients: (.+)\",\n",
    "        \"Average Evaluation Improvement Over Baseline\": r\"Average improvement over naive baseline: (.+)\",\n",
    "        \"Average Training Loss\": r\"Average training loss across clients: (.+)\",\n",
    "        \"Average Training Accuracy\": r\"Average training accuracy across clients: (.+)\",\n",
    "        \"Average Test Loss\": r\"Average test loss across clients: (.+)\",\n",
    "        \"Average Test Accuracy\": r\"Average test accuracy across clients: (.+)\"\n",
    "    }\n",
    "\n",
    "    # Extract the relevant data\n",
    "    extracted_data = {}\n",
    "    for line in lines:\n",
    "        for key, pattern in patterns.items():\n",
    "            match = re.search(pattern, line)\n",
    "            if match:\n",
    "                extracted_data[key] = match.group(1)  # Extract the first matched group\n",
    "\n",
    "    return extracted_data\n",
    "\n",
    "# Parse every result log in a directory into the nested results store\n",
    "def process_results_in_directory(directory, methods_to_evaluate, results_dict=None):\n",
    "    \"\"\"Parse the result logs in `directory` and record their summary metrics.\n",
    "\n",
    "    Log files are expected to be named ``output_log_<method>[_mu<mu>]_<k>cls.txt``;\n",
    "    files for methods not listed in `methods_to_evaluate` are skipped.\n",
    "\n",
    "    Args:\n",
    "        directory (str): Directory containing the result logs.\n",
    "        methods_to_evaluate (list[str]): Method labels to keep (e.g. 'FedAvg', 'FedProx_mu0.5').\n",
    "        results_dict (dict, optional): Store to fill, shaped {client_classes: {metric: {method: value}}}.\n",
    "            Defaults to the module-level `results`.\n",
    "    \"\"\"\n",
    "    store = results if results_dict is None else results_dict\n",
    "    wanted = set(methods_to_evaluate)\n",
    "    # '\\.' so that only a literal '.txt' suffix matches\n",
    "    name_regex = re.compile(r\"output_log_(.+?)(_mu[\\d.]+)?_(\\d+)cls\\.txt\")\n",
    "\n",
    "    # os.listdir order is filesystem-dependent; sort so results are collected in a reproducible order\n",
    "    for file_name in sorted(os.listdir(directory)):\n",
    "        if not file_name.endswith(\".txt\"):\n",
    "            continue\n",
    "        match = name_regex.search(file_name)\n",
    "        if match is None:\n",
    "            continue\n",
    "        # Method label, re-attaching the optional '_mu<mu>' suffix (e.g. 'FedProx' + '_mu0.5')\n",
    "        method = match.group(1) + (match.group(2) or \"\")\n",
    "        if method not in wanted:\n",
    "            continue\n",
    "        client_classes = int(match.group(3))\n",
    "\n",
    "        extracted_data = extract_results_from_file(os.path.join(directory, file_name))\n",
    "\n",
    "        # Pre-create every known metric so each client_classes entry has the same set of rows\n",
    "        if client_classes not in store:\n",
    "            store[client_classes] = {metric: {} for metric in metric_names}\n",
    "        for metric, value in extracted_data.items():\n",
    "            store[client_classes].setdefault(metric, {})[method] = value\n",
    "\n",
    "# Process all the result files for the specified methods\n",
    "process_results_in_directory(result_directory, methods_to_evaluate)\n",
    "\n",
    "# One row per (client_classes, metric) and one column per method (None where a run is missing).\n",
    "# Rows are ordered by client_classes so the table does not depend on file discovery order.\n",
    "rows = [\n",
    "    [client_classes, metric] + [methods.get(method) for method in methods_to_evaluate]\n",
    "    for client_classes, metrics in sorted(results.items())\n",
    "    for metric, methods in metrics.items()\n",
    "]\n",
    "columns = [\"client_classes\", \"metric\"] + methods_to_evaluate\n",
    "df = pd.DataFrame(rows, columns=columns)\n",
    "\n",
    "# Save the summary next to the logs it was built from, so different models / E settings\n",
    "# do not overwrite each other's summary\n",
    "df.to_csv(os.path.join(result_directory, 'summary_results.csv'), index=False)\n",
    "\n",
    "# Display the table (rich rendering as the cell's last expression)\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Statistical Error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import re\n",
    "import os\n",
    "\n",
    "def plot_metric_vs_client_classes(results, metric_name, methods_to_evaluate, output_dir, selected_client_classes=None,\n",
    "                                  x_label=\"Statistical Heterogeneity (# Classes)\", y_label=None, flip_x_axis=False):\n",
    "    \"\"\"\n",
    "    Plot one summary metric against the number of classes per client (one line per method)\n",
    "    and save it as '<metric_name>_vs_Client_Classes.png' in `output_dir`.\n",
    "\n",
    "    Args:\n",
    "        results (dict): {client_classes: {metric: {method: value}}}; values are converted with float().\n",
    "        metric_name (str): The metric to plot (e.g., \"Global Train Loss\").\n",
    "        methods_to_evaluate (list): Method labels to include, in legend order.\n",
    "        output_dir (str): Directory where the plot is saved.\n",
    "        selected_client_classes (list, optional): Client-class counts to place on the x-axis.\n",
    "            If None or empty, every client-class count present in `results` is used.\n",
    "        x_label (str): Label for the x-axis (default \"Statistical Heterogeneity (# Classes)\").\n",
    "        y_label (str, optional): Label for the y-axis; defaults to `metric_name`.\n",
    "        flip_x_axis (bool): If True, invert the x-axis so larger numbers appear on the left.\n",
    "    \"\"\"\n",
    "    # Figure size, marker size, font sizes, and transparency\n",
    "    fig_size = (4, 4)\n",
    "    marker_size = 4\n",
    "    font_size = 10\n",
    "    legend_font_size = 8\n",
    "    default_alpha = 1.0\n",
    "    grid_alpha = 0.3\n",
    "\n",
    "    # Colors from the 'tab10' palette\n",
    "    tab10_colors = plt.get_cmap('tab10').colors\n",
    "\n",
    "    # Fixed styles for the baseline methods\n",
    "    method_styles = {\n",
    "        'FedAvg': {'color': tab10_colors[0], 'linestyle': '--', 'marker': 'o', 'label': 'FedAvg', 'alpha': 0.9},  # Blue\n",
    "        'LocalTrain': {'color': tab10_colors[7], 'linestyle': '--', 'marker': 's', 'label': 'LocalTrain', 'alpha': default_alpha}  # Grey\n",
    "    }\n",
    "\n",
    "    # FedProx labels carry their mu value ('..._mu<mu>'); sort by mu so the smallest gets the 'small lambda' style\n",
    "    mu_values = []\n",
    "    for method in methods_to_evaluate:\n",
    "        if 'FedProx' not in method:\n",
    "            continue\n",
    "        match = re.search(r'_mu([0-9.]+)', method)\n",
    "        if match:\n",
    "            mu_values.append((method, float(match.group(1))))\n",
    "    mu_values.sort(key=lambda x: x[1])\n",
    "    mu_labels = ['small lambda', 'medium lambda', 'large lambda']\n",
    "    fedprox_colors = [tab10_colors[2], tab10_colors[1], tab10_colors[3]]  # Green, Orange, Red\n",
    "    fedprox_markers = ['D', '*', 'x']  # Diamond, Star, Multiply\n",
    "\n",
    "    for (method, _), label, color, marker in zip(mu_values, mu_labels, fedprox_colors, fedprox_markers):\n",
    "        alpha = 0.9 if label == 'large lambda' else default_alpha\n",
    "        method_styles[method] = {\n",
    "            'color': color, 'linestyle': '-', 'marker': marker, 'label': f'FedProx ({label})', 'alpha': alpha\n",
    "        }\n",
    "\n",
    "    # Methods without a predefined style (e.g. a 4th+ FedProx mu value, which zip() above drops)\n",
    "    # get a fallback style rather than raising KeyError at plot time\n",
    "    fallback_colors = [c for i, c in enumerate(tab10_colors) if i not in (0, 1, 2, 3, 7)]\n",
    "    unstyled = [m for m in methods_to_evaluate if m not in method_styles]\n",
    "    for i, method in enumerate(unstyled):\n",
    "        method_styles[method] = {\n",
    "            'color': fallback_colors[i % len(fallback_colors)], 'linestyle': '-.', 'marker': 'v',\n",
    "            'label': method, 'alpha': default_alpha\n",
    "        }\n",
    "\n",
    "    # x positions: the requested client-class counts, or every count present in the results\n",
    "    client_classes_list = sorted(selected_client_classes) if selected_client_classes else sorted(results.keys())\n",
    "\n",
    "    # y values per method; missing entries become NaN so the line shows a gap\n",
    "    metric_data = {method: [] for method in methods_to_evaluate}\n",
    "    for client_classes in client_classes_list:\n",
    "        for method in methods_to_evaluate:\n",
    "            metric_value = results.get(client_classes, {}).get(metric_name, {}).get(method)\n",
    "            metric_data[method].append(float('nan') if metric_value is None else float(metric_value))\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=fig_size)\n",
    "    for method in methods_to_evaluate:\n",
    "        style = method_styles[method]\n",
    "        ax.plot(client_classes_list, metric_data[method], color=style['color'], linestyle=style['linestyle'],\n",
    "                marker=style['marker'], markersize=marker_size, label=style['label'], alpha=style['alpha'])\n",
    "\n",
    "    ax.set_xlabel(x_label, fontsize=font_size)\n",
    "    ax.set_ylabel(metric_name if y_label is None else y_label, fontsize=font_size)\n",
    "\n",
    "    # Remove top and right borders\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "\n",
    "    if flip_x_axis:\n",
    "        ax.invert_xaxis()\n",
    "\n",
    "    # Legend below the plot, grid with transparency\n",
    "    ax.legend(fontsize=legend_font_size, loc='upper center', bbox_to_anchor=(0.5, -0.25), ncol=2)\n",
    "    ax.grid(True, alpha=grid_alpha)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(os.path.join(output_dir, f'{metric_name}_vs_Client_Classes.png'), bbox_inches='tight')\n",
    "    # Close explicitly: the function is called repeatedly and open figures would accumulate\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "plotting_dir = os.path.join(result_directory, \"figures\")\n",
    "os.makedirs(plotting_dir, exist_ok=True)\n",
    "\n",
    "# Client-class counts shown on the x-axis of every summary plot\n",
    "x_client_classes = [2, 4, 6, 8, 10]\n",
    "\n",
    "# (metric to plot, y-axis label override; None means use the metric name)\n",
    "summary_plots = [\n",
    "    (\"Global Evaluation Loss\", None),\n",
    "    (\"Average Evaluation Loss\", None),\n",
    "    (\"Global Train Loss\", None),\n",
    "    (\"Average Training Loss\", None),\n",
    "    (\"Global Evaluation Accuracy\", \"Global Test Accuracy\"),\n",
    "    (\"Average Evaluation Accuracy\", None),\n",
    "    (\"Average Evaluation Improvement Over Baseline\", None),\n",
    "]\n",
    "\n",
    "for summary_metric, axis_label in summary_plots:\n",
    "    plot_metric_vs_client_classes(\n",
    "        results=results,\n",
    "        metric_name=summary_metric,\n",
    "        methods_to_evaluate=methods_to_evaluate,\n",
    "        selected_client_classes=x_client_classes,\n",
    "        output_dir=plotting_dir,\n",
    "        y_label=axis_label,\n",
    "        flip_x_axis=True,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Total Error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Function to extract a specific metric (Global Train Loss, Local Train Loss, etc.) from the file content\n",
    "def extract_metric(file_path, metric_name):\n",
    "    \"\"\"\n",
    "    Extract the specified metric from the file content.\n",
    "\n",
    "    Args:\n",
    "        file_path (str): The path to the file containing results.\n",
    "        metric_name (str): The name of the metric to extract.\n",
    "\n",
    "    Returns:\n",
    "        list: A list of values for the specified metric.\n",
    "    \"\"\"\n",
    "    with open(file_path, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "\n",
    "    # Look for the line that contains the metric of interest\n",
    "    for line in lines:\n",
    "        if f\"{metric_name}:\" in line:\n",
    "            # Extract and return the list of comma-separated values as floats\n",
    "            metric_str = line.strip().replace(f\"{metric_name}: \", \"\")\n",
    "            metric_values = list(map(float, metric_str.split(',')))\n",
    "            return metric_values\n",
    "    return []\n",
    "\n",
    "# Function to process the results in the directory and store the specified metric for each method\n",
    "def process_metrics_in_directory(directory, methods_to_evaluate, client_classes, metric_name):\n",
    "    metric_dict = {}\n",
    "\n",
    "    for file_name in os.listdir(directory):\n",
    "        if file_name.endswith(\".txt\"):\n",
    "            # Extract method and client classes from filename (assuming consistent naming convention)\n",
    "            method_match = re.search(r\"output_log_(.+?)_(\\d+)cls.txt\", file_name)\n",
    "            if method_match:\n",
    "                method = method_match.group(1)\n",
    "                file_client_classes = int(method_match.group(2))\n",
    "\n",
    "                # Only process the specified number of client classes\n",
    "                if file_client_classes != client_classes:\n",
    "                    continue\n",
    "\n",
    "                # Only process files for methods we're interested in\n",
    "                if method not in methods_to_evaluate:\n",
    "                    continue\n",
    "\n",
    "                # Full path to the file\n",
    "                file_path = os.path.join(directory, file_name)\n",
    "\n",
    "                # Extract the metric from the file\n",
    "                metric_values = extract_metric(file_path, metric_name)\n",
    "\n",
    "                # Store the metric under the corresponding method\n",
    "                if metric_values:\n",
    "                    metric_dict[method] = metric_values\n",
    "\n",
    "    return metric_dict\n",
    "\n",
    "# Plot one per-round metric for several methods at a fixed heterogeneity level\n",
    "def plot_metric(metric_dict, metric_name, client_classes, output_dir, ylabel=None):\n",
    "    \"\"\"\n",
    "    Plot the specified metric for each method as a function of communication rounds and save it\n",
    "    as '<metric_name>_vs_Comm_Rounds_<client_classes>cls.png' in `output_dir`.\n",
    "\n",
    "    Args:\n",
    "        metric_dict (dict): Method label -> list of per-round metric values.\n",
    "        metric_name (str): The name of the metric (used in the file name and as the default y label).\n",
    "        client_classes (int): The number of client classes for the experiment (used in the file name).\n",
    "        output_dir (str): Directory to save the plot.\n",
    "        ylabel (str, optional): y-axis label; defaults to `metric_name`.\n",
    "    \"\"\"\n",
    "    if not metric_dict:\n",
    "        # No matching logs (wrong directory, metric label or client_classes): don't save an empty figure\n",
    "        print(f\"No data for '{metric_name}' ({client_classes} classes); skipping plot.\")\n",
    "        return\n",
    "\n",
    "    # Figure size, font sizes and grid transparency\n",
    "    fig_size = (4, 4)\n",
    "    font_size = 10\n",
    "    legend_font_size = 8\n",
    "    grid_alpha = 0.3\n",
    "\n",
    "    # Colors from the 'tab10' palette\n",
    "    tab10_colors = plt.get_cmap('tab10').colors\n",
    "\n",
    "    # Fixed styles for the baseline methods\n",
    "    method_styles = {\n",
    "        'FedAvg': {'color': tab10_colors[0], 'linestyle': '--', 'label': 'FedAvg'},  # Blue\n",
    "        'LocalTrain': {'color': tab10_colors[7], 'linestyle': '--', 'label': 'LocalTrain'},  # Grey\n",
    "    }\n",
    "\n",
    "    # FedProx labels carry their mu value ('..._mu<mu>'); sort by mu so the smallest gets the 'small lambda' style\n",
    "    mu_values = []\n",
    "    for method in metric_dict:\n",
    "        if 'FedProx' not in method:\n",
    "            continue\n",
    "        match = re.search(r'_mu([0-9.]+)', method)\n",
    "        if match:\n",
    "            mu_values.append((method, float(match.group(1))))\n",
    "    mu_values.sort(key=lambda x: x[1])\n",
    "    mu_labels = ['small lambda', 'medium lambda', 'large lambda']\n",
    "    fedprox_colors = [tab10_colors[2], tab10_colors[1], tab10_colors[3]]  # Green, Orange, Red\n",
    "\n",
    "    # Assign styles to the sorted FedProx methods\n",
    "    for (method, _), label, color in zip(mu_values, mu_labels, fedprox_colors):\n",
    "        method_styles[method] = {'color': color, 'linestyle': '-', 'label': f'FedProx ({label})'}\n",
    "\n",
    "    # Methods without a predefined style (e.g. a 4th+ FedProx mu value, which zip() above drops)\n",
    "    # get a fallback style rather than raising KeyError at plot time\n",
    "    fallback_colors = [c for i, c in enumerate(tab10_colors) if i not in (0, 1, 2, 3, 7)]\n",
    "    unstyled = [m for m in metric_dict if m not in method_styles]\n",
    "    for i, method in enumerate(unstyled):\n",
    "        method_styles[method] = {\n",
    "            'color': fallback_colors[i % len(fallback_colors)], 'linestyle': '-.', 'label': method\n",
    "        }\n",
    "\n",
    "    # Plot the metric as a function of communication rounds (rounds are numbered from 1)\n",
    "    fig, ax = plt.subplots(figsize=fig_size)\n",
    "    for method, values in metric_dict.items():\n",
    "        style = method_styles[method]\n",
    "        ax.plot(range(1, len(values) + 1), values, color=style['color'], linestyle=style['linestyle'],\n",
    "                label=style['label'])\n",
    "\n",
    "    ax.set_xlabel('# Communication Rounds', fontsize=font_size)\n",
    "    ax.set_ylabel(metric_name if ylabel is None else ylabel, fontsize=font_size)\n",
    "\n",
    "    # Remove top and right borders\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "\n",
    "    # Legend below the plot, grid with transparency\n",
    "    ax.legend(fontsize=legend_font_size, loc='upper center', bbox_to_anchor=(0.5, -0.25), ncol=2)\n",
    "    ax.grid(True, alpha=grid_alpha)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(os.path.join(output_dir, f'{metric_name}_vs_Comm_Rounds_{client_classes}cls.png'), bbox_inches='tight')\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total-error plots: per-round curves of the FedProx variants at a fixed heterogeneity level.\n",
    "# Reuse `methods_fedprox` from the setup cell instead of reassigning the shared\n",
    "# `methods_to_evaluate`, which the Statistical Error plots rely on.\n",
    "total_error_methods = methods_fedprox\n",
    "\n",
    "# Number of classes per client whose runs are plotted\n",
    "client_classes = 2\n",
    "\n",
    "# Define where to save the plots (created here too, in case this section runs on its own)\n",
    "plotting_dir = os.path.join(result_directory, \"figures\")\n",
    "os.makedirs(plotting_dir, exist_ok=True)\n",
    "\n",
    "# (metric label as written in the logs, y-axis label)\n",
    "total_error_metrics = [\n",
    "    (\"Global Train losses\", \"Global Training Loss\"),\n",
    "    (\"Average Client Training Losses\", \"Local Training Loss\"),\n",
    "    (\"Global Evaluation losses\", \"Global Test Loss\"),\n",
    "    (\"Average Client Evaluation losses\", \"Local Test Loss\"),\n",
    "]\n",
    "\n",
    "for log_metric, axis_label in total_error_metrics:\n",
    "    metric_dict = process_metrics_in_directory(result_directory, total_error_methods, client_classes, log_metric)\n",
    "    plot_metric(metric_dict, log_metric, client_classes, plotting_dir, ylabel=axis_label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Local Epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Function to extract a specific metric (Global Train Loss, Local Train Loss, etc.) from the file content\n",
    "def extract_metric(file_path, metric_name):\n",
    "    \"\"\"\n",
    "    Extract the specified metric from the file content.\n",
    "\n",
    "    Args:\n",
    "        file_path (str): The path to the file containing results.\n",
    "        metric_name (str): The name of the metric to extract.\n",
    "\n",
    "    Returns:\n",
    "        list: A list of values for the specified metric.\n",
    "    \"\"\"\n",
    "    with open(file_path, 'r') as f:\n",
    "        lines = f.readlines()\n",
    "\n",
    "    # Look for the line that contains the metric of interest\n",
    "    for line in lines:\n",
    "        if f\"{metric_name}:\" in line:\n",
    "            # Extract and return the list of comma-separated values as floats\n",
    "            metric_str = line.strip().replace(f\"{metric_name}: \", \"\")\n",
    "            metric_values = list(map(float, metric_str.split(',')))\n",
    "            return metric_values\n",
    "    return []\n",
    "\n",
    "# Function to process the results in the directory and store the specified metric for the selected method and E\n",
    "def process_metrics_for_method(directory, method_to_evaluate, local_update_numbers, metric_name, target_client_classes):\n",
    "    metric_dict = {}\n",
    "\n",
    "    for local_update_number in local_update_numbers:\n",
    "        subdirectory = os.path.join(directory, f\"E{local_update_number}\")\n",
    "        if os.path.exists(subdirectory):\n",
    "            for file_name in os.listdir(subdirectory):\n",
    "                if file_name.endswith(\".txt\"):\n",
    "                    # Extract method and client classes from filename (assuming consistent naming convention)\n",
    "                    method_match = re.search(r\"output_log_(.+?)_(\\d+)cls.txt\", file_name)\n",
    "                    if method_match:\n",
    "                        method = method_match.group(1)\n",
    "                        file_client_classes = int(method_match.group(2))\n",
    "                        # Only process the specified method\n",
    "                        if method != method_to_evaluate or file_client_classes != target_client_classes:\n",
    "                            continue\n",
    "\n",
    "\n",
    "                        # Full path to the file\n",
    "                        file_path = os.path.join(subdirectory, file_name)\n",
    "\n",
    "                        # Extract the metric from the file\n",
    "                        metric_values = extract_metric(file_path, metric_name)\n",
    "\n",
    "                        # Store the metric under the corresponding local update number\n",
    "                        if metric_values:\n",
    "                            metric_dict[local_update_number] = metric_values\n",
    "\n",
    "    return metric_dict\n",
    "\n",
    "# Plot one method's per-round metric for several local-epoch settings (E values)\n",
    "def plot_metric_over_local_updates(metric_dict, metric_name, method, output_dir, ylabel=None):\n",
    "    \"\"\"\n",
    "    Plot the specified metric for the selected method over different local update numbers (E values)\n",
    "    and save it as '<metric_name>_vs_Comm_Rounds_<method>.png' in `output_dir`.\n",
    "\n",
    "    Args:\n",
    "        metric_dict (dict): Local update number (E) -> list of per-round metric values.\n",
    "        metric_name (str): The name of the metric (used in the file name and as the default y label).\n",
    "        method (str): The method being evaluated (used in the file name).\n",
    "        output_dir (str): Directory to save the plot.\n",
    "        ylabel (str, optional): y-axis label; defaults to `metric_name`.\n",
    "    \"\"\"\n",
    "    if not metric_dict:\n",
    "        # No matching logs for any E value: don't save an empty figure\n",
    "        print(f\"No data for '{metric_name}' ({method}); skipping plot.\")\n",
    "        return\n",
    "\n",
    "    # Figure size, font sizes and grid transparency\n",
    "    fig_size = (4, 4)\n",
    "    font_size = 10\n",
    "    legend_font_size = 8\n",
    "    grid_alpha = 0.3\n",
    "\n",
    "    # One curve per E value over communication rounds (numbered from 1); matplotlib picks the colors\n",
    "    fig, ax = plt.subplots(figsize=fig_size)\n",
    "    for local_update_number, values in metric_dict.items():\n",
    "        ax.plot(range(1, len(values) + 1), values, label=f\"# Local Epochs = {local_update_number}\")\n",
    "\n",
    "    ax.set_xlabel('# Communication Rounds', fontsize=font_size)\n",
    "    ax.set_ylabel(metric_name if ylabel is None else ylabel, fontsize=font_size)\n",
    "\n",
    "    # Remove top and right borders\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "\n",
    "    # Legend below the plot, grid with transparency\n",
    "    ax.legend(fontsize=legend_font_size, loc='upper center', bbox_to_anchor=(0.5, -0.25), ncol=2)\n",
    "    ax.grid(True, alpha=grid_alpha)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(os.path.join(output_dir, f'{metric_name}_vs_Comm_Rounds_{method}.png'), bbox_inches='tight')\n",
    "    plt.close(fig)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local-epoch comparison: logs for each E live in ../results/<model>/E<E>/, so read from the model-level directory.\n",
    "# Separate names keep the E-specific `result_directory` / `plotting_dir` used by earlier sections intact.\n",
    "epochs_result_directory = f'../results/{model}'\n",
    "epochs_plotting_dir = os.path.join(epochs_result_directory, \"figures\")\n",
    "os.makedirs(epochs_plotting_dir, exist_ok=True)\n",
    "\n",
    "# Method, E values (local epochs) and metric to compare\n",
    "method_to_evaluate = \"FedProx_mu0.5\"\n",
    "local_update_numbers = [5, 10, 15, 20]\n",
    "metric_name = \"Global Train losses\"\n",
    "target_client_classes = 5\n",
    "\n",
    "# Process the result files for the specified method and E values\n",
    "metric_dict = process_metrics_for_method(epochs_result_directory, method_to_evaluate, local_update_numbers, metric_name, target_client_classes=target_client_classes)\n",
    "\n",
    "# Plot the metric over communication rounds, one curve per E value\n",
    "plot_metric_over_local_updates(metric_dict, metric_name, method_to_evaluate, epochs_plotting_dir, ylabel=\"Global Training Loss\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local-epoch comparison: logs for each E live in ../results/<model>/E<E>/, so read from the model-level directory.\n",
    "# Separate names keep the E-specific `result_directory` / `plotting_dir` used by earlier sections intact.\n",
    "epochs_result_directory = f'../results/{model}'\n",
    "epochs_plotting_dir = os.path.join(epochs_result_directory, \"figures\")\n",
    "os.makedirs(epochs_plotting_dir, exist_ok=True)\n",
    "\n",
    "# Method, E values (local epochs) and metric to compare\n",
    "method_to_evaluate = \"FedProx_mu2.5\"\n",
    "local_update_numbers = [5, 10, 15, 20]\n",
    "metric_name = \"Average Client Training Losses\"\n",
    "target_client_classes = 5\n",
    "\n",
    "# Process the result files for the specified method and E values\n",
    "metric_dict = process_metrics_for_method(epochs_result_directory, method_to_evaluate, local_update_numbers, metric_name, target_client_classes=target_client_classes)\n",
    "\n",
    "# Plot the metric over communication rounds, one curve per E value\n",
    "plot_metric_over_local_updates(metric_dict, metric_name, method_to_evaluate, epochs_plotting_dir, ylabel=\"Local Training Loss\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fedprox",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
