{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_plot_setting():\n",
    "    \"\"\"Apply the notebook-wide matplotlib settings and create the 'plots' directory.\"\"\"\n",
    "    plt.style.use(\"seaborn-v0_8-talk\")\n",
    "    # Override the font sizes of axes titles and axis labels.\n",
    "    plt.rcParams[\"axes.titlesize\"] = \"x-large\"\n",
    "    plt.rcParams[\"axes.labelsize\"] = \"xx-large\"\n",
    "    # No error if the directory is already there.\n",
    "    os.makedirs(\"plots\", exist_ok=True)\n",
    "\n",
    "\n",
    "load_plot_setting()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_data(path_template='/results/stats.json',\n",
    "                 diff_seeds=(1, 2, 3),\n",
    "                 unlearn_methods=('retraining', 'bmt', 'mmt', 'standalone')):\n",
    "    \"\"\"Aggregate the stats of every unlearning method across seeds.\n",
    "\n",
    "    Args:\n",
    "        path_template: Path of the stats JSON file. It is passed through\n",
    "            ``str.format(method=..., seed=...)``, so it may contain ``{method}``\n",
    "            and ``{seed}`` placeholders. The default has none, which means\n",
    "            every method and seed loads the same file.\n",
    "            NOTE(review): supply the real per-method/per-seed layout; until\n",
    "            then the spread across seeds carries no information.\n",
    "        diff_seeds: Seeds whose runs are combined.\n",
    "        unlearn_methods: Method names; each becomes a key of the result.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping each method to ``{'mean': DataFrame, 'std': DataFrame}``,\n",
    "        computed per row position across the seeds.\n",
    "    \"\"\"\n",
    "    outputs = {}\n",
    "    for method in unlearn_methods:\n",
    "        per_seed = []\n",
    "        for seed in diff_seeds:\n",
    "            path = path_template.format(method=method, seed=seed)\n",
    "            print(f'Loading stats from {path}')\n",
    "            # Context manager so the file handle is closed right after parsing\n",
    "            with open(path, 'r') as stats_file:\n",
    "                stats = pd.DataFrame(json.load(stats_file))\n",
    "            if method == 'standalone':\n",
    "                # Standalone runs only: back-fill missing values\n",
    "                stats = stats.bfill()\n",
    "            # Row position is the key that lines up rows of different seeds\n",
    "            stats['index'] = range(len(stats))\n",
    "            per_seed.append(stats)\n",
    "        grouped = pd.concat(per_seed).groupby('index')\n",
    "        outputs[method] = {'mean': grouped.mean(), 'std': grouped.std()}\n",
    "\n",
    "    return outputs\n",
    "\n",
    "\n",
    "data = prepare_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import json\n",
    "\n",
    "# File paths\n",
    "log_file_path = 'files/output.log'\n",
    "json_file_path = 'plots/extracted_data.json'\n",
    "\n",
    "# One list per metric, filled in the order the values appear in the log\n",
    "train_accs = []\n",
    "train_losses = []\n",
    "test_accs = []\n",
    "test_losses = []\n",
    "\n",
    "# A complete number: optional sign, at most one decimal point, optional exponent.\n",
    "# A bare [\\d.]+ would cut '1e-05' down to '1' and would capture the trailing\n",
    "# period of '0.95.', which float() rejects.\n",
    "number_pattern = r'([-+]?(?:\\d+\\.?\\d*|\\.\\d+)(?:[eE][-+]?\\d+)?)'\n",
    "train_acc_pattern = re.compile(r'train_acc:\\s*' + number_pattern)\n",
    "train_loss_pattern = re.compile(r'train_loss:\\s*' + number_pattern)\n",
    "test_acc_pattern = re.compile(r'test_acc:\\s*' + number_pattern)\n",
    "test_loss_pattern = re.compile(r'test_loss:\\s*' + number_pattern)\n",
    "\n",
    "# Pair each pattern with the list that collects its matches\n",
    "collectors = [\n",
    "    (train_acc_pattern, train_accs),\n",
    "    (train_loss_pattern, train_losses),\n",
    "    (test_acc_pattern, test_accs),\n",
    "    (test_loss_pattern, test_losses),\n",
    "]\n",
    "\n",
    "# Scan the log line by line; a line may contribute to several metrics\n",
    "with open(log_file_path, 'r') as log_file:\n",
    "    for line in log_file:\n",
    "        for pattern, values in collectors:\n",
    "            match = pattern.search(line)\n",
    "            if match:\n",
    "                values.append(float(match.group(1)))\n",
    "\n",
    "# Prepare data to be saved in JSON\n",
    "extracted_data = {\n",
    "    'train_acc': train_accs,\n",
    "    'train_loss': train_losses,\n",
    "    'test_acc': test_accs,\n",
    "    'test_loss': test_losses,\n",
    "}\n",
    "\n",
    "# Save the extracted data to a JSON file\n",
    "with open(json_file_path, 'w') as json_file:\n",
    "    json.dump(extracted_data, json_file, indent=4)\n",
    "\n",
    "print(f'Extracted data saved to {json_file_path}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# JSON file holding the local-training records\n",
    "file_path = 'results_local/local_training_0.json'\n",
    "\n",
    "with open(file_path) as file:\n",
    "    data = json.load(file)\n",
    "\n",
    "# Leave the final record out of the plot\n",
    "data = data[:-1]\n",
    "\n",
    "# x/y series: one 'step' and one 'loss' per remaining record\n",
    "steps = [record['step'] for record in data]\n",
    "losses = [record['loss'] for record in data]\n",
    "\n",
    "# Loss curve on an explicit figure/axes pair\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(steps, losses, marker='o')\n",
    "ax.set_title('Loss vs Step')\n",
    "ax.set_xlabel('Step')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.grid(True)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# JSON file holding the local-training records\n",
    "file_path = 'results_local/local_training.json'\n",
    "\n",
    "with open(file_path, 'r') as file:\n",
    "    list_of_dicts = json.load(file)\n",
    "\n",
    "# Keep only the records that carry a 'loss' value\n",
    "data = [d for d in list_of_dicts if 'loss' in d]\n",
    "\n",
    "# x/y series: one 'step' and one 'loss' per kept record\n",
    "steps = [item['step'] for item in data]\n",
    "losses = [item['loss'] for item in data]\n",
    "\n",
    "# Plot loss against the recorded step so the x-axis matches its 'Step' label\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(steps, losses, marker='o')\n",
    "ax.set_title('Loss vs Step')\n",
    "ax.set_xlabel('Step')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.grid(True)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gemma 2b peft\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# JSON file holding the local-training records\n",
    "file_path = 'results_local/local_training.json'\n",
    "\n",
    "with open(file_path, 'r') as file:\n",
    "    list_of_dicts = json.load(file)\n",
    "\n",
    "# Keep only the records that carry a 'loss' value\n",
    "data = [d for d in list_of_dicts if 'loss' in d]\n",
    "\n",
    "# x/y series: one 'step' and one 'loss' per kept record\n",
    "steps = [item['step'] for item in data]\n",
    "losses = [item['loss'] for item in data]\n",
    "\n",
    "# Plot loss against the recorded step so the x-axis matches its 'Step' label\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(steps, losses, marker='o')\n",
    "ax.set_title('Loss vs Step')\n",
    "ax.set_xlabel('Step')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.grid(True)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# JSON file holding the local-training records\n",
    "file_path = 'results_local/local_training.json'\n",
    "\n",
    "with open(file_path, 'r') as file:\n",
    "    list_of_dicts = json.load(file)\n",
    "\n",
    "# Keep only the records that carry a 'loss' value\n",
    "data = [d for d in list_of_dicts if 'loss' in d]\n",
    "\n",
    "# x/y series: one 'step' and one 'loss' per kept record\n",
    "steps = [item['step'] for item in data]\n",
    "losses = [item['loss'] for item in data]\n",
    "\n",
    "# Plot loss against the recorded step so the x-axis matches its 'Step' label\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(steps, losses, marker='o')\n",
    "ax.set_title('Loss vs Step')\n",
    "ax.set_xlabel('Step')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.grid(True)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from collections import defaultdict\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "# JSON stats file to inspect\n",
    "file_path = 'stats_local.json'\n",
    "\n",
    "# Parse the file; the handle is closed when the with-block exits\n",
    "with open(file_path) as file:\n",
    "    data = json.load(file)\n",
    "\n",
    "# Show the parsed content (later cells iterate over it as a list of dicts)\n",
    "print(data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt tyqiangz/multilingual-sentiments\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt tyqiangz/multilingual-sentiments\n",
    "# Training loss of every record in `data` that reports one\n",
    "train_loss_values = [entry['train_loss'] for entry in data if 'train_loss' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (loss)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_loss_values, marker='o')\n",
    "ax.set_title('Train Loss Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Loss')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt tyqiangz/multilingual-sentiments\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt seara/ru_go_emotions\n",
    "# Training loss of every record in `data` that reports one\n",
    "train_loss_values = [entry['train_loss'] for entry in data if 'train_loss' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (loss)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_loss_values, marker='o')\n",
    "ax.set_title('Train Loss Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Loss')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt seara/ru_go_emotions\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter bmt seara/ru_go_emotions\n",
    "# Training loss of every record in `data` that reports one\n",
    "train_loss_values = [entry['train_loss'] for entry in data if 'train_loss' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (loss)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_loss_values, marker='o')\n",
    "ax.set_title('Train Loss Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Loss')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter seara/ru_go_emotions\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter seara/ru_go_emotions\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter seara/ru_go_emotions\n",
    "# Training loss of every record in `data` that reports one\n",
    "train_loss_values = [entry['train_loss'] for entry in data if 'train_loss' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (loss)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_loss_values, marker='o')\n",
    "ax.set_title('Train Loss Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Loss')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter seara/ru_go_emotions\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter seara/ru_go_emotions\n",
    "# Training loss of every record in `data` that reports one\n",
    "train_loss_values = [entry['train_loss'] for entry in data if 'train_loss' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (loss)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_loss_values, marker='o')\n",
    "ax.set_title('Train Loss Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Loss')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from collections import defaultdict\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "# JSON stats file to inspect\n",
    "file_path = 'results/stats.json'\n",
    "\n",
    "# Parse the file; the handle is closed when the with-block exits\n",
    "with open(file_path) as file:\n",
    "    data = json.load(file)\n",
    "\n",
    "# Show the parsed content (later cells iterate over it as a list of dicts)\n",
    "print(data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter results_retraining_language-identification_10steps 8 clients\n",
    "\n",
    "# Test accuracy of every record in `data` that reports one\n",
    "test_acc_values = [entry['test_acc'] for entry in data if 'test_acc' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (test accuracy)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(test_acc_values, marker='o')\n",
    "ax.set_title('Test Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Test Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter results_bmt_language-identification_10steps 8 clients\n",
    "\n",
    "# Test accuracy of every record in `data` that reports one\n",
    "test_acc_values = [entry['test_acc'] for entry in data if 'test_acc' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (test accuracy)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(test_acc_values, marker='o')\n",
    "ax.set_title('Test Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Test Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter results_mmt_language-identification_10steps 8 clients\n",
    "\n",
    "# Test accuracy of every record in `data` that reports one\n",
    "test_acc_values = [entry['test_acc'] for entry in data if 'test_acc' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (test accuracy)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(test_acc_values, marker='o')\n",
    "ax.set_title('Test Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Test Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter results_retraining_ru_go_emotions_100steps 8 clients\n",
    "\n",
    "# Test accuracy of every record in `data` that reports one\n",
    "test_acc_values = [entry['test_acc'] for entry in data if 'test_acc' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (test accuracy)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(test_acc_values, marker='o')\n",
    "ax.set_title('Test Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Test Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter results_mmt_ru_go_emotions_100steps 8 clients\n",
    "\n",
    "# Test accuracy of every record in `data` that reports one\n",
    "test_acc_values = [entry['test_acc'] for entry in data if 'test_acc' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (test accuracy)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(test_acc_values, marker='o')\n",
    "ax.set_title('Test Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Test Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter results_bmt_ru_go_emotions_100steps 8 clients\n",
    "\n",
    "# Test accuracy of every record in `data` that reports one\n",
    "test_acc_values = [entry['test_acc'] for entry in data if 'test_acc' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (test accuracy)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(test_acc_values, marker='o')\n",
    "ax.set_title('Test Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Test Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt seara/ru_go_emotions 8 clients\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter bmt seara/ru_go_emotions 8 clients\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter bmt papluca/language-identification 8 clients\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt papluca/language-identification 8 clients\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt tyqiangz/multilingual-sentiments 8 clients\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter bmt tyqiangz/multilingual-sentiments 8 clients\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter bmt tyqiangz/multilingual-sentiments 7 clients\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter bmt tyqiangz/multilingual-sentiments 6 clients\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter bmt tyqiangz/multilingual-sentiments 5 clients\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter bmt tyqiangz/multilingual-sentiments 4 clients\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter bmt tyqiangz/multilingual-sentiments 3 clients\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt tyqiangz/multilingual-sentiments\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt tyqiangz/multilingual-sentiments\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt tyqiangz/multilingual-sentiments\n",
    "\n",
    "# Test accuracy of every record in `data` that reports one\n",
    "test_acc_values = [entry['test_acc'] for entry in data if 'test_acc' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (test accuracy)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(test_acc_values, marker='o')\n",
    "ax.set_title('Test Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Test Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt tyqiangz/multilingual-sentiments\n",
    "\n",
    "# Test accuracy of every record in `data` that reports one\n",
    "test_acc_values = [entry['test_acc'] for entry in data if 'test_acc' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (test accuracy)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(test_acc_values, marker='o')\n",
    "ax.set_title('Test Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Test Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt tyqiangz/multilingual-sentiments\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt seara/ru_go_emotions\n",
    "\n",
    "# Test accuracy of every record in `data` that reports one\n",
    "test_acc_values = [entry['test_acc'] for entry in data if 'test_acc' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (test accuracy)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(test_acc_values, marker='o')\n",
    "ax.set_title('Test Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Test Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter mmt\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter bmt\n",
    "\n",
    "# Test accuracy of every record in `data` that reports one\n",
    "test_acc_values = [entry['test_acc'] for entry in data if 'test_acc' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (test accuracy)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(test_acc_values, marker='o')\n",
    "ax.set_title('Test Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Test Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter bmt\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gemma 2b #gpt2 full parameter\n",
    "\n",
    "# Training loss of every record in `data` that reports one\n",
    "train_loss_values = [entry['train_loss'] for entry in data if 'train_loss' in entry]\n",
    "\n",
    "# Title and y-label name the metric that is actually plotted (loss)\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_loss_values, marker='o')\n",
    "ax.set_title('Train Loss Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Loss')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#gpt2 full parameter\n",
    "\n",
    "# Train accuracy of every record in `data` that reports one\n",
    "train_acc_values = [entry['train_acc'] for entry in data if 'train_acc' in entry]\n",
    "\n",
    "# Accuracy curve; the x-axis is the position in the list\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.plot(train_acc_values, marker='o')\n",
    "ax.set_title('Train Accuracy Over Iterations')\n",
    "ax.set_xlabel('Iteration')\n",
    "ax.set_ylabel('Train Accuracy')\n",
    "ax.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run tags from the original cell header: gemma 2b, gpt2 full parameter\n",
    "\n",
    "# Training loss of every record in `data` that logs a 'train_loss' entry.\n",
    "# The variable, title and y-axis label name the loss because that is what is plotted.\n",
    "train_loss_values = [d['train_loss'] for d in data if 'train_loss' in d]\n",
    "\n",
    "# Plotting the values\n",
    "plt.figure(figsize=(10, 5))\n",
    "plt.plot(train_loss_values, marker='o')\n",
    "plt.title('Train Loss Over Iterations')\n",
    "plt.xlabel('Iteration')\n",
    "plt.ylabel('Train Loss')\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run: gpt2, full parameter\n",
    "# Pull 'train_acc' out of each record in `data`, ignoring records without it.\n",
    "train_acc_values = [row['train_acc'] for row in data if 'train_acc' in row]\n",
    "\n",
    "# Accuracy curve, one point per kept record.\n",
    "plt.figure(figsize=(10, 5))\n",
    "plt.plot(train_acc_values, marker='o')\n",
    "plt.gca().set(\n",
    "    title='Train Accuracy Over Iterations',\n",
    "    xlabel='Iteration',\n",
    "    ylabel='Train Accuracy',\n",
    ")\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run tags from the original cell header: gemma 2b, gpt2 full parameter\n",
    "\n",
    "# Training loss of every record in `data` that logs a 'train_loss' entry.\n",
    "# The variable, title and y-axis label name the loss because that is what is plotted.\n",
    "train_loss_values = [d['train_loss'] for d in data if 'train_loss' in d]\n",
    "\n",
    "# Plotting the values\n",
    "plt.figure(figsize=(10, 5))\n",
    "plt.plot(train_loss_values, marker='o')\n",
    "plt.title('Train Loss Over Iterations')\n",
    "plt.xlabel('Iteration')\n",
    "plt.ylabel('Train Loss')\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run tags from the original cell header: gemma 2b, gpt2 full parameter\n",
    "\n",
    "# Training loss of every record in `data` that logs a 'train_loss' entry.\n",
    "# The variable, title and y-axis label name the loss because that is what is plotted.\n",
    "train_loss_values = [d['train_loss'] for d in data if 'train_loss' in d]\n",
    "\n",
    "# Plotting the values\n",
    "plt.figure(figsize=(10, 5))\n",
    "plt.plot(train_loss_values, marker='o')\n",
    "plt.title('Train Loss Over Iterations')\n",
    "plt.xlabel('Iteration')\n",
    "plt.ylabel('Train Loss')\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run: gemma 2b, peft\n",
    "# Keep the logged 'train_acc' value of every record in `data` that has one.\n",
    "train_acc_values = [record['train_acc'] for record in data if 'train_acc' in record]\n",
    "\n",
    "# Draw the curve; the x position is the order of the kept records.\n",
    "plt.figure(figsize=(10, 5))\n",
    "plt.plot(train_acc_values, marker='o')\n",
    "plt.gca().set(\n",
    "    title='Train Accuracy Over Iterations',\n",
    "    xlabel='Iteration',\n",
    "    ylabel='Train Accuracy',\n",
    ")\n",
    "plt.grid(True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(ax, mean, std, label, lw=3, ls='-'):\n",
    "    \"\"\"Draw `mean` as a line on `ax` with a shaded band spanning mean-std to mean+std.\n",
    "\n",
    "    Args:\n",
    "        ax: object whose `plot` and `fill_between` methods are called.\n",
    "        mean: values for the line; must support len() and +/- with `std`.\n",
    "        std: half-width of the shaded band around `mean`.\n",
    "        label: legend label attached to the line.\n",
    "        lw: line width forwarded to `ax.plot`.\n",
    "        ls: line style forwarded to `ax.plot`.\n",
    "    \"\"\"\n",
    "    ax.plot(mean, label=label, lw=lw, ls=ls)\n",
    "    lower, upper = mean - std, mean + std\n",
    "    # The band is placed at x = 0 .. len(mean)-1.\n",
    "    ax.fill_between(range(len(mean)), lower, upper, alpha=.5)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5.5, 3))\n",
    "\n",
    "# (key in `data`, legend label, line style), drawn in this order.\n",
    "curves = [\n",
    "    # (\"sisa\", \"SISA\", '-'),  # disabled in the original cell\n",
    "    (\"retraining\", \"RfS\", '-'),\n",
    "    (\"bmt\", \"BMT\", '-'),\n",
    "    (\"mmt\", \"MMT\", '-'),\n",
    "    (\"standalone\", \"Standalone\", 'dashed'),\n",
    "]\n",
    "for key, label, ls in curves:\n",
    "    plot(ax, data[key][\"mean\"][\"test_acc\"], data[key][\"std\"][\"test_acc\"], label=label, ls=ls)\n",
    "\n",
    "ax.legend(loc=\"lower left\", ncol=2)\n",
    "ax.set_xlabel(\"No. Rounds\")\n",
    "ax.set_ylabel(\"Test Accuracy (%)\")\n",
    "ax.set_ylim(50, 95)\n",
    "\n",
    "figure_name = \"plots/mnist_sequential.png\"\n",
    "plt.savefig(figure_name, bbox_inches=\"tight\", dpi=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### FashionMNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_data():\n",
    "    \"\"\"Load the per-seed stats of every unlearning method and aggregate them.\n",
    "\n",
    "    For each method and seed, reads\n",
    "    `../sequential_unlearning/results/fashion_mnist_parallel/seed-<seed>/<method>/stats.json`\n",
    "    into a DataFrame, tags each row with its position, and groups rows that\n",
    "    share a position across seeds.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping method name -> {\"mean\": DataFrame, \"std\": DataFrame},\n",
    "        the per-position mean and standard deviation across seeds.\n",
    "    \"\"\"\n",
    "    diff_seeds = [1,2,3]\n",
    "    unlearn_methods = [\"retraining\", \"bmt\", \"mmt\", \"standalone\"]\n",
    "\n",
    "    # aggregate results across different seeds\n",
    "    outputs = {}\n",
    "    for method in unlearn_methods:\n",
    "        per_seed = []\n",
    "        for seed in diff_seeds:\n",
    "            path = f\"../sequential_unlearning/results/fashion_mnist_parallel/seed-{seed}/{method}/stats.json\"\n",
    "            print(f\"Loading stats from {path}\")\n",
    "            # Context manager so the file handle is closed as soon as parsing is done.\n",
    "            with open(path, \"r\") as stats_file:\n",
    "                stats = json.load(stats_file)\n",
    "            stats = pd.DataFrame(stats)     # convert to dataframe for easier aggregation\n",
    "            if \"standalone\" in method:\n",
    "                # Back-fill missing values (presumably standalone logs have gaps; TODO confirm).\n",
    "                stats = stats.bfill()\n",
    "            stats[\"index\"] = range(len(stats))  # create index for grouping\n",
    "            per_seed.append(stats)\n",
    "        grouped = pd.concat(per_seed).groupby(\"index\")   # group by index across different seeds\n",
    "        outputs[method] = {\"mean\": grouped.mean(), \"std\": grouped.std()}\n",
    "\n",
    "    return outputs\n",
    "\n",
    "data = prepare_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(ax, mean, std, label, lw=3, ls='-'):\n",
    "    \"\"\"Draw `mean` as a line on `ax` with a shaded band spanning mean-std to mean+std.\n",
    "\n",
    "    Args:\n",
    "        ax: object whose `plot` and `fill_between` methods are called.\n",
    "        mean: values for the line; must support len() and +/- with `std`.\n",
    "        std: half-width of the shaded band around `mean`.\n",
    "        label: legend label attached to the line.\n",
    "        lw: line width forwarded to `ax.plot`.\n",
    "        ls: line style forwarded to `ax.plot`.\n",
    "    \"\"\"\n",
    "    ax.plot(mean, label=label, lw=lw, ls=ls)\n",
    "    # The band is placed at x = 0 .. len(mean)-1.\n",
    "    ax.fill_between(range(len(mean)), mean-std, mean+std, alpha=.5)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5.5,3))\n",
    "# Column plotted for every curve; all calls below read it so the figure never mixes metrics.\n",
    "metrics = \"test_acc\"\n",
    "# plot(ax, data[\"sisa\"][\"mean\"][metrics], data[\"sisa\"][\"std\"][metrics], label=\"SISA\");\n",
    "plot(ax, data[\"retraining\"][\"mean\"][metrics], data[\"retraining\"][\"std\"][metrics], label=\"RfS\");\n",
    "plot(ax, data[\"bmt\"][\"mean\"][metrics], data[\"bmt\"][\"std\"][metrics], label=\"BMT\");\n",
    "plot(ax, data[\"mmt\"][\"mean\"][metrics], data[\"mmt\"][\"std\"][metrics], label=\"MMT\");\n",
    "plot(ax, data[\"standalone\"][\"mean\"][metrics], data[\"standalone\"][\"std\"][metrics], label=\"Standalone\", ls='dashed')\n",
    "ax.legend(loc=\"lower left\", ncol=2);\n",
    "ax.set_xlabel(\"No. Rounds\");\n",
    "ax.set_ylabel(\"Test Accuracy (%)\");\n",
    "ax.set_ylim(40, 85);\n",
    "figure_name = \"plots/fashion_mnist_sequential.png\"\n",
    "plt.savefig(figure_name, bbox_inches=\"tight\", dpi=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CIFAR10 dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_data():\n",
    "    \"\"\"Load the per-seed stats of every unlearning method and aggregate them.\n",
    "\n",
    "    For each method and seed, reads\n",
    "    `../sequential_unlearning/results/cifar10_parallel_longer_2/seed-<seed>/<method>/stats.json`\n",
    "    into a DataFrame, tags each row with its position, and groups rows that\n",
    "    share a position across seeds.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping method name -> {\"mean\": DataFrame, \"std\": DataFrame},\n",
    "        the per-position mean and standard deviation across seeds.\n",
    "    \"\"\"\n",
    "    diff_seeds = [1,2,3]\n",
    "    unlearn_methods = [\"retraining\", \"bmt\", \"mmt\", \"standalone\"]\n",
    "\n",
    "    # aggregate results across different seeds\n",
    "    outputs = {}\n",
    "    for method in unlearn_methods:\n",
    "        per_seed = []\n",
    "        for seed in diff_seeds:\n",
    "            path = f\"../sequential_unlearning/results/cifar10_parallel_longer_2/seed-{seed}/{method}/stats.json\"\n",
    "            print(f\"Loading stats from {path}\")\n",
    "            # Context manager so the file handle is closed as soon as parsing is done.\n",
    "            with open(path, \"r\") as stats_file:\n",
    "                stats = json.load(stats_file)\n",
    "            stats = pd.DataFrame(stats)     # convert to dataframe for easier aggregation\n",
    "            if \"standalone\" in method:\n",
    "                # Back-fill missing values (presumably standalone logs have gaps; TODO confirm).\n",
    "                stats = stats.bfill()\n",
    "            stats[\"index\"] = range(len(stats))  # create index for grouping\n",
    "            per_seed.append(stats)\n",
    "        grouped = pd.concat(per_seed).groupby(\"index\")   # group by index across different seeds\n",
    "        outputs[method] = {\"mean\": grouped.mean(), \"std\": grouped.std()}\n",
    "\n",
    "    return outputs\n",
    "\n",
    "data = prepare_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(ax, mean, std, label, lw=2, ls='-'):\n",
    "    \"\"\"Draw `mean` as a line on `ax` with a shaded band spanning mean-std to mean+std.\n",
    "\n",
    "    Args:\n",
    "        ax: object whose `plot` and `fill_between` methods are called.\n",
    "        mean: values for the line; must support len() and +/- with `std`.\n",
    "        std: half-width of the shaded band around `mean`.\n",
    "        label: legend label attached to the line.\n",
    "        lw: line width forwarded to `ax.plot`.\n",
    "        ls: line style forwarded to `ax.plot`.\n",
    "    \"\"\"\n",
    "    ax.plot(mean, label=label, lw=lw, ls=ls)\n",
    "    # The band is placed at x = 0 .. len(mean)-1.\n",
    "    ax.fill_between(range(len(mean)), mean-std, mean+std, alpha=.5)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5.5,3))\n",
    "# Column plotted for every curve; all calls below read it so the figure never mixes metrics.\n",
    "metrics = \"test_acc\"\n",
    "plot(ax, data[\"retraining\"][\"mean\"][metrics], data[\"retraining\"][\"std\"][metrics], label=\"RfS\");\n",
    "plot(ax, data[\"bmt\"][\"mean\"][metrics], data[\"bmt\"][\"std\"][metrics], label=\"BMT\");\n",
    "plot(ax, data[\"mmt\"][\"mean\"][metrics], data[\"mmt\"][\"std\"][metrics], label=\"MMT\");\n",
    "plot(ax, data[\"standalone\"][\"mean\"][metrics], data[\"standalone\"][\"std\"][metrics], label=\"Standalone\", ls='dashed')\n",
    "# plot(ax, data[\"sisa\"][\"mean\"][metrics], data[\"sisa\"][\"std\"][metrics], label=\"SISA\");\n",
    "ax.legend(loc=\"lower left\", ncol=2);\n",
    "ax.set_xlabel(\"No. Rounds\");\n",
    "ax.set_ylabel(\"Test Accuracy (%)\");\n",
    "ax.set_ylim(0, 45);\n",
    "figure_name = \"plots/cifar10_sequential.png\"\n",
    "plt.savefig(figure_name, bbox_inches=\"tight\", dpi=200)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  CIFAR100 dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_data():\n",
    "    \"\"\"Load the per-seed stats of every unlearning method and aggregate them.\n",
    "\n",
    "    For each method and seed, reads\n",
    "    `../sequential_unlearning/results/cifar100_parallel_more-class_longer/seed-<seed>/<method>/stats.json`\n",
    "    into a DataFrame, tags each row with its position, and groups rows that\n",
    "    share a position across seeds.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping method name -> {\"mean\": DataFrame, \"std\": DataFrame},\n",
    "        the per-position mean and standard deviation across seeds.\n",
    "    \"\"\"\n",
    "    diff_seeds = [1,2,3]\n",
    "    unlearn_methods = [\"retraining\", \"mmt\", \"bmt\", \"standalone\"]\n",
    "\n",
    "    # aggregate results across different seeds\n",
    "    outputs = {}\n",
    "    for method in unlearn_methods:\n",
    "        per_seed = []\n",
    "        for seed in diff_seeds:\n",
    "            path = f\"../sequential_unlearning/results/cifar100_parallel_more-class_longer/seed-{seed}/{method}/stats.json\"\n",
    "            print(f\"Loading stats from {path}\")\n",
    "            # Context manager so the file handle is closed as soon as parsing is done.\n",
    "            with open(path, \"r\") as stats_file:\n",
    "                stats = json.load(stats_file)\n",
    "            stats = pd.DataFrame(stats)     # convert to dataframe for easier aggregation\n",
    "            if \"standalone\" in method:\n",
    "                # Back-fill missing values (presumably standalone logs have gaps; TODO confirm).\n",
    "                stats = stats.bfill()\n",
    "            stats[\"index\"] = range(len(stats))  # create index for grouping\n",
    "            per_seed.append(stats)\n",
    "        grouped = pd.concat(per_seed).groupby(\"index\")   # group by index across different seeds\n",
    "        outputs[method] = {\"mean\": grouped.mean(), \"std\": grouped.std()}\n",
    "\n",
    "    return outputs\n",
    "\n",
    "data = prepare_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(ax, mean, std, label, lw=2, ls='-'):\n",
    "    \"\"\"Draw `mean` as a line on `ax` with a shaded band spanning mean-std to mean+std.\n",
    "\n",
    "    Args:\n",
    "        ax: object whose `plot` and `fill_between` methods are called.\n",
    "        mean: values for the line; must support len() and +/- with `std`.\n",
    "        std: half-width of the shaded band around `mean`.\n",
    "        label: legend label attached to the line.\n",
    "        lw: line width forwarded to `ax.plot`.\n",
    "        ls: line style forwarded to `ax.plot`.\n",
    "    \"\"\"\n",
    "    ax.plot(mean, label=label, lw=lw, ls=ls)\n",
    "    # The band is placed at x = 0 .. len(mean)-1.\n",
    "    ax.fill_between(range(len(mean)), mean-std, mean+std, alpha=.5)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5.5,3))\n",
    "# Column plotted for every curve; all calls below read it so the figure never mixes metrics.\n",
    "metrics = \"test_acc\"\n",
    "plot(ax, data[\"retraining\"][\"mean\"][metrics], data[\"retraining\"][\"std\"][metrics], label=\"RfS\");\n",
    "plot(ax, data[\"bmt\"][\"mean\"][metrics], data[\"bmt\"][\"std\"][metrics], label=\"BMT\");\n",
    "plot(ax, data[\"mmt\"][\"mean\"][metrics], data[\"mmt\"][\"std\"][metrics], label=\"MMT\");\n",
    "plot(ax, data[\"standalone\"][\"mean\"][metrics], data[\"standalone\"][\"std\"][metrics], label=\"Standalone\", ls='dashed')\n",
    "# plot(ax, data[\"sisa\"][\"mean\"][metrics], data[\"sisa\"][\"std\"][metrics], label=\"SISA\");\n",
    "ax.legend(loc=\"lower left\", ncol=2);\n",
    "ax.set_xlabel(\"No. Rounds\");\n",
    "ax.set_ylabel(\"Test Accuracy (%)\");\n",
    "figure_name = \"plots/cifar100_sequential.png\"\n",
    "plt.savefig(figure_name, bbox_inches=\"tight\", dpi=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "unlearning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
