{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f7a86983",
   "metadata": {},
   "source": [
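    "# All Methods: Accuracy, Forgetting and Time vs. Memory Size\n",
    "\n",
    "Compares every method on Average Accuracy (CNN), Average Accuracy (NME), Forgetting (CNN) and run time as a function of memory size, parsed from each method's training logs. Run time is the number of minutes between the first and last timestamped log entries. Figures are saved as PDF and PNG under `./images/<keyword>/`, together with a CSV comparing Average Accuracy (CNN) across methods."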
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6f1069a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "import re\n",
    "import warnings\n",
    "from datetime import datetime\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd  # Import pandas for DataFrame creation\n",
    "\n",
    "# Timestamp searched for in log lines, e.g. 2024-01-31 12:34:56,789.\n",
    "_TIMESTAMP_RE = re.compile(r'(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2},\\d{3})')\n",
    "_TIME_FORMAT = '%Y-%m-%d %H:%M:%S,%f'\n",
    "\n",
    "# Series kept per method; 'X' holds the memory size parsed from the file name.\n",
    "_SERIES_KEYS = ('X', 'Accuracy_CNN', 'Accuracy_NME', 'Forgetting_CNN', 'Time_Spent')\n",
    "\n",
    "# Summary lines searched in each log; the number right after '<label>:' is captured.\n",
    "_METRIC_PATTERNS = {\n",
    "    key: re.compile(re.escape(label) + r'\\s*:\\s*(-?\\d*\\.?\\d+(?:[eE][-+]?\\d+)?)')\n",
    "    for key, label in (\n",
    "        ('Accuracy_CNN', 'Average Accuracy (CNN)'),\n",
    "        ('Accuracy_NME', 'Average Accuracy (NME)'),\n",
    "        ('Forgetting_CNN', 'Forgetting (CNN)'),\n",
    "    )\n",
    "}\n",
    "\n",
    "# Font sizes shared by every figure, for a consistent look.\n",
    "_TITLE_FONTSIZE = 24\n",
    "_LABEL_FONTSIZE = 20\n",
    "_LEGEND_FONTSIZE = 12\n",
    "\n",
    "# One entry per figure: (series key, title, y-axis label, output file stem, marker).\n",
    "_PLOT_SPECS = (\n",
    "    ('Accuracy_CNN', 'Average Accuracy (CNN)', 'Accuracy', 'average_accuracy_cnn_combined', 'o'),\n",
    "    ('Accuracy_NME', 'Average Accuracy (NME)', 'Accuracy', 'average_accuracy_nme_combined', 'o'),\n",
    "    ('Forgetting_CNN', 'Forgetting (CNN)', 'Forgetting (CNN)', 'forgetting_cnn_combined', 'o'),\n",
    "    ('Time_Spent', 'Time Spent', 'Time Spent (min.)', 'time_spent_combined', '*'),\n",
    ")\n",
    "\n",
    "\n",
    "def _parse_run_log(log_file):\n",
    "    \"\"\"Parse one training log into its final metrics and its duration.\n",
    "\n",
    "    Returns a dict with 'Accuracy_CNN', 'Accuracy_NME', 'Forgetting_CNN' (the\n",
    "    last value logged for each) and 'Time_Spent' (minutes between the first\n",
    "    and last timestamped lines), or None when the log has no timestamp or\n",
    "    lacks any of the three metrics (e.g. an unfinished run).\n",
    "    \"\"\"\n",
    "    first_ts = last_ts = None\n",
    "    metrics = {}\n",
    "    with open(log_file, 'r', encoding='utf-8', errors='replace') as f:\n",
    "        for line in f:\n",
    "            ts_match = _TIMESTAMP_RE.search(line)\n",
    "            if ts_match:\n",
    "                # Only timestamped lines bound the run, so a trailing blank\n",
    "                # line or traceback cannot break the duration computation.\n",
    "                first_ts = first_ts or ts_match.group(1)\n",
    "                last_ts = ts_match.group(1)\n",
    "            for key, pattern in _METRIC_PATTERNS.items():\n",
    "                metric_match = pattern.search(line)\n",
    "                if metric_match:\n",
    "                    # Later lines overwrite earlier ones: the final value wins.\n",
    "                    metrics[key] = float(metric_match.group(1))\n",
    "                    break\n",
    "    if first_ts is None or len(metrics) < len(_METRIC_PATTERNS):\n",
    "        return None\n",
    "    start = datetime.strptime(first_ts, _TIME_FORMAT)\n",
    "    end = datetime.strptime(last_ts, _TIME_FORMAT)\n",
    "    metrics['Time_Spent'] = (end - start).total_seconds() / 60  # minutes\n",
    "    return metrics\n",
    "\n",
    "\n",
    "def _collect_runs(folder_path, keyword, seed):\n",
    "    \"\"\"Return (memory size, metrics) pairs for every usable log in folder_path.\n",
    "\n",
    "    A log is used when its file name contains 'memory_<X>_<seed>_<keyword>'\n",
    "    (e.g. memory_2000_1993_resnet32.log) and _parse_run_log accepts it.\n",
    "    \"\"\"\n",
    "    # Escape keyword so glob/regex metacharacters in it are matched literally.\n",
    "    name_re = re.compile(r'memory_(\\d+)_' + re.escape(f'{seed}_{keyword}'))\n",
    "    runs = []\n",
    "    for log_file in glob.glob(os.path.join(folder_path, f'*{glob.escape(keyword)}*.log')):\n",
    "        # Match on the file name only, so directory names cannot produce a match.\n",
    "        name_match = name_re.search(os.path.basename(log_file))\n",
    "        if not name_match:\n",
    "            continue\n",
    "        metrics = _parse_run_log(log_file)\n",
    "        if metrics is None:\n",
    "            warnings.warn(f'Skipping {log_file}: missing timestamps or summary metrics')\n",
    "            continue\n",
    "        runs.append((int(name_match.group(1)), metrics))\n",
    "    return runs\n",
    "\n",
    "\n",
    "def _plot_metric(combined_data, methods, method_colors, dashed_methods, output_dir,\n",
    "                 metric_key, title, ylabel, file_stem, marker):\n",
    "    \"\"\"Plot one metric against memory size for all methods, save it as PDF and PNG, then show it.\"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(8, 4))\n",
    "    for method in methods:\n",
    "        series = combined_data[method]\n",
    "        ax.plot(series['X'], series[metric_key], label=method, marker=marker,\n",
    "                linestyle='--' if method in dashed_methods else '-',\n",
    "                color=method_colors[method])\n",
    "    ax.set_title(title, fontsize=_TITLE_FONTSIZE)\n",
    "    ax.set_xlabel('Memory Size', fontsize=_LABEL_FONTSIZE)\n",
    "    ax.set_ylabel(ylabel, fontsize=_LABEL_FONTSIZE)\n",
    "    ax.grid(True)\n",
    "    ax.legend(loc='best', fontsize=_LEGEND_FONTSIZE, ncol=2)\n",
    "    fig.tight_layout()\n",
    "    for ext in ('pdf', 'png'):\n",
    "        fig.savefig(os.path.join(output_dir, f'{file_stem}.{ext}'))\n",
    "    plt.show()\n",
    "    plt.close(fig)  # release the figure so repeated runs do not accumulate them\n",
    "\n",
    "\n",
    "def extract_and_plot_logs_multiple_methods(folder_paths, methods, keyword='resnet32',\n",
    "                                           seed=1993, dashed_methods=('DER', 'FOSTER', 'MEMO')):\n",
    "    \"\"\"Plot accuracy, forgetting and run time against memory size for several methods.\n",
    "\n",
    "    For each (folder, method) pair, log files matching '*<keyword>*.log' whose\n",
    "    name contains 'memory_<X>_<seed>_<keyword>' are parsed, and X is used as the\n",
    "    memory size. Four figures are shown and saved (PDF and PNG) under\n",
    "    './images/<keyword>/', along with a CSV comparing Average Accuracy (CNN).\n",
    "    Logs missing timestamps or any summary metric are skipped with a warning.\n",
    "\n",
    "    Args:\n",
    "        folder_paths: Directories to search, aligned one-to-one with methods.\n",
    "        methods: Display names; a name given twice pools its folders.\n",
    "        keyword: Substring used in the glob and in the file-name pattern.\n",
    "        seed: Number expected between X and keyword in the file name\n",
    "            (presumably the run seed; defaults to 1993).\n",
    "        dashed_methods: Methods drawn with dashed lines; all others are solid.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame indexed by 'Memory Size' with one 'Accuracy_CNN_<method>'\n",
    "        column per method (NaN where a method has no run at that size).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If folder_paths and methods differ in length.\n",
    "    \"\"\"\n",
    "    folder_paths, methods = list(folder_paths), list(methods)\n",
    "    if len(folder_paths) != len(methods):\n",
    "        raise ValueError(f'folder_paths ({len(folder_paths)}) and methods ({len(methods)}) '\n",
    "                         'must have the same length')\n",
    "    # Deduplicate while keeping order: one curve, one colour and one column per name.\n",
    "    unique_methods = list(dict.fromkeys(methods))\n",
    "\n",
    "    output_dir = f'./images/{keyword}/'\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "    combined_data = {method: {key: [] for key in _SERIES_KEYS} for method in unique_methods}\n",
    "    for folder_path, method in zip(folder_paths, methods):\n",
    "        series = combined_data[method]\n",
    "        for x_value, metrics in _collect_runs(folder_path, keyword, seed):\n",
    "            series['X'].append(x_value)\n",
    "            for key in _SERIES_KEYS[1:]:\n",
    "                series[key].append(metrics[key])\n",
    "\n",
    "    # Order each method's series by memory size (stable: ties keep discovery order).\n",
    "    for series in combined_data.values():\n",
    "        order = sorted(range(len(series['X'])), key=series['X'].__getitem__)\n",
    "        for key in _SERIES_KEYS:\n",
    "            series[key] = [series[key][i] for i in order]\n",
    "\n",
    "    default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "    method_colors = {\n",
    "        method: default_colors[i % len(default_colors)]\n",
    "        for i, method in enumerate(unique_methods)\n",
    "    }\n",
    "\n",
    "    for metric_key, title, ylabel, file_stem, marker in _PLOT_SPECS:\n",
    "        _plot_metric(combined_data, unique_methods, method_colors, dashed_methods, output_dir,\n",
    "                     metric_key, title, ylabel, file_stem, marker)\n",
    "\n",
    "    # Wide table: one Accuracy_CNN column per method, outer-joined on memory size.\n",
    "    accuracy_cnn_df = None\n",
    "    for method in unique_methods:\n",
    "        method_df = pd.DataFrame({\n",
    "            'Memory Size': pd.Series(combined_data[method]['X'], dtype='int64'),\n",
    "            f'Accuracy_CNN_{method}': pd.Series(combined_data[method]['Accuracy_CNN'], dtype='float64'),\n",
    "        })\n",
    "        # Seed from the first method even when it has no rows; testing `.empty`\n",
    "        # instead would silently drop that method's column.\n",
    "        if accuracy_cnn_df is None:\n",
    "            accuracy_cnn_df = method_df\n",
    "        else:\n",
    "            accuracy_cnn_df = pd.merge(accuracy_cnn_df, method_df, on='Memory Size', how='outer')\n",
    "    if accuracy_cnn_df is None:  # no methods were given\n",
    "        accuracy_cnn_df = pd.DataFrame({'Memory Size': pd.Series(dtype='int64')})\n",
    "\n",
    "    accuracy_cnn_df = accuracy_cnn_df.sort_values(by='Memory Size').set_index('Memory Size')\n",
    "    accuracy_cnn_df.to_csv(os.path.join(output_dir, 'average_accuracy_cnn_comparison.csv'))\n",
    "    return accuracy_cnn_df\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd208e9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log directory of each method, keyed by the name shown in plots and tables.\n",
    "# One mapping keeps names and paths aligned; two hand-maintained parallel\n",
    "# lists silently mis-pair as soon as one of them is edited.\n",
    "LOG_SUBDIR = 'cifar100/0/10'\n",
    "METHOD_LOG_DIRS = {\n",
    "    'DER': f'./logs/der/{LOG_SUBDIR}/',\n",
    "    'FOSTER': f'./logs/foster/{LOG_SUBDIR}/',\n",
    "    'MEMO': f'./logs/memo/{LOG_SUBDIR}/',\n",
    "    'Replay': f'./logs/replay/{LOG_SUBDIR}/',\n",
    "    'iCaRL': f'./logs/icarl/{LOG_SUBDIR}/',\n",
    "    'BiC': f'./logs/bic/{LOG_SUBDIR}/',\n",
    "    'WA': f'./logs/wa/{LOG_SUBDIR}/',\n",
    "    'Ours': f'./logs/replay_ours_swa/{LOG_SUBDIR}/',\n",
    "}\n",
    "\n",
    "# Parallel lists in the form expected by extract_and_plot_logs_multiple_methods.\n",
    "folder_paths = list(METHOD_LOG_DIRS.values())\n",
    "methods = list(METHOD_LOG_DIRS)\n",
    "\n",
    "# Parse the logs, show/save the figures, and display the Average Accuracy (CNN) table.\n",
    "extract_and_plot_logs_multiple_methods(folder_paths, methods, keyword='resnet32')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "peer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
