{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f1164884",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import yaml\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from pathlib import Path\n",
    "\n",
    "pd.set_option('display.max_columns', None)\n",
    "pd.set_option('display.width', None)\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "sns.set_palette('husl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c9c2a14a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_sweep_results_with_hyperparams(sweep_ids: list, wandb_dir: Path) -> pd.DataFrame:\n",
    "    \"\"\"Load all run results including hyperparameters from multiple sweeps.\"\"\"\n",
    "    all_results = []\n",
    "    \n",
    "    for sweep_id in sweep_ids:\n",
    "        sweep_dir = wandb_dir / f\"sweep-{sweep_id}\"\n",
    "        if not sweep_dir.exists():\n",
    "            print(f\"Warning: Sweep directory not found for {sweep_id}\")\n",
    "            continue\n",
    "            \n",
    "        print(f\"Loading sweep: {sweep_id}\")\n",
    "        \n",
    "        # Get all run IDs from sweep config files\n",
    "        for config_file in sweep_dir.glob(\"config-*.yaml\"):\n",
    "            run_id = config_file.stem.replace(\"config-\", \"\")\n",
    "            \n",
    "            # Find the corresponding run directory\n",
    "            run_dirs = list(wandb_dir.glob(f\"run-*-{run_id}\"))\n",
    "            if not run_dirs:\n",
    "                continue\n",
    "            \n",
    "            run_dir = run_dirs[0]\n",
    "            \n",
    "            # Load config from sweep directory\n",
    "            with open(config_file) as f:\n",
    "                config = yaml.safe_load(f)\n",
    "            \n",
    "            # Load summary from run directory\n",
    "            summary_file = run_dir / \"files\" / \"wandb-summary.json\"\n",
    "            if not summary_file.exists():\n",
    "                continue\n",
    "                \n",
    "            with open(summary_file) as f:\n",
    "                summary = json.load(f)\n",
    "            \n",
    "            # Helper to extract value from config\n",
    "            def get_config_value(key):\n",
    "                val = config.get(key, {})\n",
    "                if isinstance(val, dict):\n",
    "                    return val.get(\"value\")\n",
    "                return val\n",
    "            \n",
    "            # Extract all relevant info including hyperparameters\n",
    "            result = {\n",
    "                \"sweep_id\": sweep_id,\n",
    "                \"run_id\": run_id,\n",
    "                \"setting\": get_config_value(\"setting\"),\n",
    "                \"method\": get_config_value(\"method\"),\n",
    "                \"seed\": get_config_value(\"seed\"),\n",
    "                \"lr\": get_config_value(\"lr\"),\n",
    "                \"beta_efc\": get_config_value(\"beta_efc\"),\n",
    "                \"target_lr\": get_config_value(\"target_lr\"),\n",
    "                \"batch_size\": get_config_value(\"batch_size\"),\n",
    "                \"epochs\": get_config_value(\"epochs\"),\n",
    "                \"final_avg_accuracy\": summary.get(\"final_avg_accuracy\"),\n",
    "                \"final_forgetting\": summary.get(\"final_forgetting\"),\n",
    "                \"layer_size\": get_config_value(\"layer_size\"),\n",
    "                \"cnn_pretrained\": get_config_value(\"cnn_pretrained\")\n",
    "            }\n",
    "            all_results.append(result)\n",
    "    \n",
    "    return pd.DataFrame(all_results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a517701",
   "metadata": {},
   "source": [
    "# Load sweeps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d8421c75",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'bp_taskIL_MNIST': 'fxyfl702', 'bp_taskIL_MNIST_Encoded': 'dwy7u31i', 'bp_taskIL_CIFAR10': 'uxwfc030', 'bp_taskIL_TinyImageNet': 'lnxoup5r', 'bp_classIL_MNIST_Encoded': 'na9m91mn', 'bp_classIL_MNIST': 'l3igjcsf', 'bp_classIL_TinyImageNet': '4w31ji4p', 'bp_classIL_CIFAR5Task': 'jt0sbq84', 'ewc_taskIL_MNIST_Encoded': '624uewad', 'ewc_taskIL_MNIST': 'ynthaqge', 'ewc_taskIL_CIFAR10': 'pv9fmucf', 'ewc_taskIL_TinyImageNet': 'z80oh9jt', 'ewc_classIL_MNIST_Encoded': 'lmndtcgs', 'ewc_classIL_TinyImageNet': '9vtbd28v', 'ewc_classIL_CIFAR5Task': '2yy94mkk', 'oewc_taskIL_MNIST': '5li65lf9', 'oewc_taskIL_MNIST_Encoded': 'haxoiwpy', 'oewc_taskIL_CIFAR10': 'k4ty6g4y', 'oewc_taskIL_TinyImageNet': '1feg452a', 'oewc_classIL_MNIST': '96rtma30', 'oewc_classIL_MNIST_Encoded': '0uwwn7pg', 'oewc_classIL_CIFAR5Task': '5orh3y5r', 'oewc_classIL_TinyImageNet': 'aqqn6apn', 'si_taskIL_MNIST': 'lx9cnj04', 'si_taskIL_MNIST_Encoded': 't814a2bw', 'si_taskIL_TinyImageNet': 'w4tyfd49', 'si_taskIL_CIFAR10': 'yfnqasl9', 'si_classIL_MNIST': '0hkmme2n', 'si_classIL_MNIST_Encoded': 'su8091cg', 'si_classIL_TinyImageNet': '6upykwgh', 'si_classIL_CIFAR5Task': '8iacqh0q', 'taskIL_MNIST_EFC': 'sboxxxxq', 'taskIL_EFC_MNIST_Encoded': 'v6nmf69q', 'taskIL_EFC_CIFAR10': 'tdjyv5su', 'taskIL_EFC_TinyImageNet': 'i1ejkeu0', 'ClassIL_EFC_MNIST': '9nlzz1az', 'ClassIL_EFC_MNIST_Encoded': 'k4by5mb7', 'derpp_taskIL_MNIST': 'hvjx1v04', 'derpp_taskIL_MNIST_Encoded': 'wix5sln6', 'derpp_taskIL_CIFAR10': '2rpifx70', 'derpp_taskIL_TinyImageNet': '0xu65xiy', 'derpp_classIL_MNIST': '4s6brbpu', 'derpp_classIL_MNIST_Encoded': '27q0bnf9', 'derpp_classIL_TinyImageNet': 't6fxfkja', 'derpp_classIL_CIFAR5Task': 'zxom37nt'}\n",
      "45\n"
     ]
    }
   ],
   "source": [
    "import wandb\n",
    "\n",
    "WANDB_DIR = Path(\"./wandb\")\n",
    "\n",
    "api = wandb.Api()\n",
    "runs = api.runs(\"equilibrium-fisher-control/icml-final-runs\")\n",
    "\n",
    "sweep_dict = {}\n",
    "for run in runs:\n",
    "    if run.sweep:\n",
    "        sweep_dict[run.sweep.name] = run.sweep.id\n",
    "\n",
    "SWEEP_DICTS = sweep_dict\n",
    "print(SWEEP_DICTS)\n",
    "print(len(SWEEP_DICTS))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9396b36b",
   "metadata": {},
   "source": [
    "# Single Sweep Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "munhs5uzwtc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading sweep: 2y74n3t7\n",
      "=== Summary for 2y74n3t7 ===\n",
      "Average Accuracy: 28.96% ± 0.84%\n",
      "Number of seeds: 10\n",
      "\n",
      "=== Individual Run Performance ===\n",
      " seed method            setting     lr  beta_efc  target_lr  final_avg_accuracy\n",
      "    0    efc TaskILTinyImageNet 0.0001       0.1        0.1               27.55\n",
      "    1    efc TaskILTinyImageNet 0.0001       0.1        0.1               29.71\n",
      "    2    efc TaskILTinyImageNet 0.0001       0.1        0.1               28.29\n",
      "    3    efc TaskILTinyImageNet 0.0001       0.1        0.1               29.91\n",
      "    4    efc TaskILTinyImageNet 0.0001       0.1        0.1               27.83\n",
      "    5    efc TaskILTinyImageNet 0.0001       0.1        0.1               28.66\n",
      "    6    efc TaskILTinyImageNet 0.0001       0.1        0.1               29.50\n",
      "    7    efc TaskILTinyImageNet 0.0001       0.1        0.1               29.32\n",
      "    8    efc TaskILTinyImageNet 0.0001       0.1        0.1               29.07\n",
      "    9    efc TaskILTinyImageNet 0.0001       0.1        0.1               29.77\n"
     ]
    }
   ],
   "source": [
    "# Summary statistics\n",
    "\n",
    "wanted_sweep = \"2y74n3t7\"\n",
    "\n",
    "df = load_sweep_results_with_hyperparams(sweep_ids=[wanted_sweep], wandb_dir=WANDB_DIR)\n",
    "\n",
    "avg_acc = df['final_avg_accuracy'].mean()\n",
    "std_acc = df['final_avg_accuracy'].std()\n",
    "n_seeds = df['seed'].nunique()\n",
    "\n",
    "print(f\"=== Summary for {wanted_sweep} ===\")\n",
    "print(f\"Average Accuracy: {avg_acc:.2f}% ± {std_acc:.2f}%\")\n",
    "print(f\"Number of seeds: {n_seeds}\")\n",
    "print()\n",
    "\n",
    "# Performance of all models in the sweep\n",
    "print(f\"=== Individual Run Performance ===\")\n",
    "performance_df = df[['seed', 'method', 'setting', 'lr', 'beta_efc', 'target_lr', 'final_avg_accuracy']].copy()\n",
    "performance_df = performance_df.sort_values('seed')\n",
    "performance_df['final_avg_accuracy'] = performance_df['final_avg_accuracy'].round(2)\n",
    "print(performance_df.to_string(index=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ad440a4",
   "metadata": {},
   "source": [
    "## Other Methods "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f21b16d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading sweep: fxyfl702\n",
      "=== Summary for bp_taskIL_MNIST ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "bp         98.14     0.09        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: dwy7u31i\n",
      "=== Summary for bp_taskIL_MNIST_Encoded ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "bp         90.73     4.28        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: uxwfc030\n",
      "=== Summary for bp_taskIL_CIFAR10 ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "bp         95.46     0.51        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: lnxoup5r\n",
      "=== Summary for bp_taskIL_TinyImageNet ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "bp         28.01     0.52        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: na9m91mn\n",
      "=== Summary for bp_classIL_MNIST_Encoded ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "bp         19.25     0.18        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: l3igjcsf\n",
      "=== Summary for bp_classIL_MNIST ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "bp          19.3     0.04        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 4w31ji4p\n",
      "=== Summary for bp_classIL_TinyImageNet ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "bp          3.79     0.09        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: jt0sbq84\n",
      "=== Summary for bp_classIL_CIFAR5Task ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "bp         19.57     0.07        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 624uewad\n",
      "=== Summary for ewc_taskIL_MNIST_Encoded ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "ewc        96.27     1.63        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: ynthaqge\n",
      "=== Summary for ewc_taskIL_MNIST ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "ewc        97.21     0.64        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: pv9fmucf\n",
      "=== Summary for ewc_taskIL_CIFAR10 ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "ewc        95.47     1.23        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: z80oh9jt\n",
      "=== Summary for ewc_taskIL_TinyImageNet ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "ewc         34.2     0.41        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: lmndtcgs\n",
      "=== Summary for ewc_classIL_MNIST_Encoded ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "ewc        19.61     0.16        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 9vtbd28v\n",
      "=== Summary for ewc_classIL_TinyImageNet ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "ewc         4.64     0.05        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 2yy94mkk\n",
      "=== Summary for ewc_classIL_CIFAR5Task ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "ewc        20.87     1.08        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 5li65lf9\n",
      "=== Summary for oewc_taskIL_MNIST ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "oewc       99.16     0.39        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: haxoiwpy\n",
      "=== Summary for oewc_taskIL_MNIST_Encoded ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "oewc       98.08     1.73        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: k4ty6g4y\n",
      "=== Summary for oewc_taskIL_CIFAR10 ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "oewc       95.87     0.68        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 1feg452a\n",
      "=== Summary for oewc_taskIL_TinyImageNet ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "oewc       35.54     0.29        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 96rtma30\n",
      "=== Summary for oewc_classIL_MNIST ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "oewc        20.2     0.84        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 0uwwn7pg\n",
      "=== Summary for oewc_classIL_MNIST_Encoded ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "oewc       19.64     0.05        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 5orh3y5r\n",
      "=== Summary for oewc_classIL_CIFAR5Task ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "oewc       21.17     2.02        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: aqqn6apn\n",
      "=== Summary for oewc_classIL_TinyImageNet ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "oewc        4.75     0.04        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: lx9cnj04\n",
      "=== Summary for si_taskIL_MNIST ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "si         97.52     0.93        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: t814a2bw\n",
      "=== Summary for si_taskIL_MNIST_Encoded ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "si          98.5     0.33        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: w4tyfd49\n",
      "=== Summary for si_taskIL_TinyImageNet ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "si         35.52     0.48        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: yfnqasl9\n",
      "=== Summary for si_taskIL_CIFAR10 ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "si         92.63     1.27        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 0hkmme2n\n",
      "=== Summary for si_classIL_MNIST ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "si         19.72     0.01        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: su8091cg\n",
      "=== Summary for si_classIL_MNIST_Encoded ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "si         19.61     0.09        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 6upykwgh\n",
      "=== Summary for si_classIL_TinyImageNet ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "si          4.74     0.04        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 8iacqh0q\n",
      "=== Summary for si_classIL_CIFAR5Task ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "si         19.81     0.13        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: sboxxxxq\n",
      "=== Summary for taskIL_MNIST_EFC ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "efc         96.8     2.45        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: v6nmf69q\n",
      "=== Summary for taskIL_EFC_MNIST_Encoded ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "efc        87.53     4.85        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: tdjyv5su\n",
      "=== Summary for taskIL_EFC_CIFAR10 ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "efc        90.61     3.19        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: i1ejkeu0\n",
      "=== Summary for taskIL_EFC_TinyImageNet ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "efc        36.79     0.41        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 9nlzz1az\n",
      "=== Summary for ClassIL_EFC_MNIST ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "efc        42.36     2.97        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: k4by5mb7\n",
      "=== Summary for ClassIL_EFC_MNIST_Encoded ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "efc        23.98     5.59        4\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: hvjx1v04\n",
      "=== Summary for derpp_taskIL_MNIST ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "derpp      98.62     0.17        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: wix5sln6\n",
      "=== Summary for derpp_taskIL_MNIST_Encoded ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "derpp      99.21     0.12        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 2rpifx70\n",
      "=== Summary for derpp_taskIL_CIFAR10 ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "derpp      95.37     0.79        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 0xu65xiy\n",
      "=== Summary for derpp_taskIL_TinyImageNet ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "derpp      19.66     0.44        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 4s6brbpu\n",
      "=== Summary for derpp_classIL_MNIST ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "derpp      71.12     2.16        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: 27q0bnf9\n",
      "=== Summary for derpp_classIL_MNIST_Encoded ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "derpp       91.4     1.19        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: t6fxfkja\n",
      "=== Summary for derpp_classIL_TinyImageNet ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "derpp       4.96     0.26        5\n",
      "\n",
      "=== Individual Run Performance ===\n",
      "Loading sweep: zxom37nt\n",
      "=== Summary for derpp_classIL_CIFAR5Task ===\n",
      "\n",
      "        Acc Mean  Acc Std  N Seeds\n",
      "method                            \n",
      "derpp      62.26     1.16        5\n",
      "\n",
      "=== Individual Run Performance ===\n"
     ]
    }
   ],
   "source": [
    "# Dictionary mapping sweep names to their IDs for \"other\" methods\n",
    "\n",
    "for sweep in SWEEP_DICTS.keys():\n",
    "    try:\n",
    "        wanted_sweep = sweep\n",
    "        df = load_sweep_results_with_hyperparams([SWEEP_DICTS[wanted_sweep]], WANDB_DIR)\n",
    "\n",
    "        # Summary statistics grouped by method\n",
    "        print(f\"=== Summary for {wanted_sweep} ===\\n\")\n",
    "        summary = df.groupby('method').agg({\n",
    "            'final_avg_accuracy': ['mean', 'std', 'count'],\n",
    "        }).round(2)\n",
    "        summary.columns = ['Acc Mean', 'Acc Std', 'N Seeds']\n",
    "        print(summary.to_string())\n",
    "\n",
    "        print(f\"\\n=== Individual Run Performance ===\")\n",
    "        # performance_df = df[['seed', 'method', 'setting', 'final_avg_accuracy']].copy()\n",
    "        # performance_df = performance_df.sort_values(['method', 'seed'])\n",
    "        # performance_df['final_avg_accuracy'] = performance_df['final_avg_accuracy'].round(2)\n",
    "        # print(performance_df.to_string(index=False))\n",
    "    except:\n",
    "        continue\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fd9b0163",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'bp_taskIL_MNIST': 'fxyfl702',\n",
       " 'bp_taskIL_MNIST_Encoded': 'dwy7u31i',\n",
       " 'bp_taskIL_CIFAR10': 'uxwfc030',\n",
       " 'bp_taskIL_TinyImageNet': 'lnxoup5r',\n",
       " 'bp_classIL_MNIST_Encoded': 'na9m91mn',\n",
       " 'bp_classIL_MNIST': 'l3igjcsf',\n",
       " 'bp_classIL_TinyImageNet': '4w31ji4p',\n",
       " 'bp_classIL_CIFAR5Task': 'jt0sbq84',\n",
       " 'ewc_taskIL_MNIST_Encoded': '624uewad',\n",
       " 'ewc_taskIL_MNIST': 'ynthaqge',\n",
       " 'ewc_taskIL_CIFAR10': 'pv9fmucf',\n",
       " 'ewc_taskIL_TinyImageNet': 'z80oh9jt',\n",
       " 'ewc_classIL_MNIST_Encoded': 'lmndtcgs',\n",
       " 'ewc_classIL_TinyImageNet': '9vtbd28v',\n",
       " 'ewc_classIL_CIFAR5Task': '2yy94mkk',\n",
       " 'oewc_taskIL_MNIST': '5li65lf9',\n",
       " 'oewc_taskIL_MNIST_Encoded': 'haxoiwpy',\n",
       " 'oewc_taskIL_CIFAR10': 'k4ty6g4y',\n",
       " 'oewc_taskIL_TinyImageNet': '1feg452a',\n",
       " 'oewc_classIL_MNIST': '96rtma30',\n",
       " 'oewc_classIL_MNIST_Encoded': '0uwwn7pg',\n",
       " 'oewc_classIL_CIFAR5Task': '5orh3y5r',\n",
       " 'oewc_classIL_TinyImageNet': 'aqqn6apn',\n",
       " 'si_taskIL_MNIST': 'lx9cnj04',\n",
       " 'si_taskIL_MNIST_Encoded': 't814a2bw',\n",
       " 'si_taskIL_TinyImageNet': 'w4tyfd49',\n",
       " 'si_taskIL_CIFAR10': 'yfnqasl9',\n",
       " 'si_classIL_MNIST': '0hkmme2n',\n",
       " 'si_classIL_MNIST_Encoded': 'su8091cg',\n",
       " 'si_classIL_TinyImageNet': '6upykwgh',\n",
       " 'si_classIL_CIFAR5Task': '8iacqh0q',\n",
       " 'taskIL_MNIST_EFC': 'sboxxxxq',\n",
       " 'taskIL_EFC_MNIST_Encoded': 'v6nmf69q',\n",
       " 'taskIL_EFC_CIFAR10': 'tdjyv5su',\n",
       " 'taskIL_EFC_TinyImageNet': 'i1ejkeu0',\n",
       " 'ClassIL_EFC_MNIST': '9nlzz1az',\n",
       " 'ClassIL_EFC_MNIST_Encoded': 'k4by5mb7',\n",
       " 'derpp_taskIL_MNIST': 'hvjx1v04',\n",
       " 'derpp_taskIL_MNIST_Encoded': 'wix5sln6',\n",
       " 'derpp_taskIL_CIFAR10': '2rpifx70',\n",
       " 'derpp_taskIL_TinyImageNet': '0xu65xiy',\n",
       " 'derpp_classIL_MNIST': '4s6brbpu',\n",
       " 'derpp_classIL_MNIST_Encoded': '27q0bnf9',\n",
       " 'derpp_classIL_TinyImageNet': 't6fxfkja',\n",
       " 'derpp_classIL_CIFAR5Task': 'zxom37nt'}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "SWEEP_DICTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "706b7bd4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading sweep: fxyfl702\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bp_taskIL_MNIST: BP | taskIL | MNIST -> 98.14 ± 0.09 (n=5)\n",
      "Loading sweep: dwy7u31i\n",
      "bp_taskIL_MNIST_Encoded: BP | taskIL | MNIST_Encoded -> 90.73 ± 4.28 (n=5)\n",
      "Loading sweep: uxwfc030\n",
      "bp_taskIL_CIFAR10: BP | taskIL | CIFAR10 -> 95.46 ± 0.51 (n=5)\n",
      "Loading sweep: lnxoup5r\n",
      "bp_taskIL_TinyImageNet: BP | taskIL | TinyImageNet -> 28.01 ± 0.52 (n=5)\n",
      "Loading sweep: na9m91mn\n",
      "bp_classIL_MNIST_Encoded: BP | classIL | MNIST_Encoded -> 19.25 ± 0.18 (n=5)\n",
      "Loading sweep: l3igjcsf\n",
      "bp_classIL_MNIST: BP | classIL | MNIST -> 19.30 ± 0.04 (n=5)\n",
      "Loading sweep: 4w31ji4p\n",
      "bp_classIL_TinyImageNet: BP | classIL | TinyImageNet -> 3.79 ± 0.09 (n=5)\n",
      "Loading sweep: jt0sbq84\n",
      "bp_classIL_CIFAR5Task: BP | classIL | CIFAR5Task -> 19.57 ± 0.07 (n=5)\n",
      "Loading sweep: 624uewad\n",
      "ewc_taskIL_MNIST_Encoded: EWC | taskIL | MNIST_Encoded -> 96.27 ± 1.63 (n=5)\n",
      "Loading sweep: ynthaqge\n",
      "ewc_taskIL_MNIST: EWC | taskIL | MNIST -> 97.21 ± 0.64 (n=5)\n",
      "Loading sweep: pv9fmucf\n",
      "ewc_taskIL_CIFAR10: EWC | taskIL | CIFAR10 -> 95.47 ± 1.23 (n=5)\n",
      "Loading sweep: z80oh9jt\n",
      "ewc_taskIL_TinyImageNet: EWC | taskIL | TinyImageNet -> 34.20 ± 0.41 (n=5)\n",
      "Loading sweep: lmndtcgs\n",
      "ewc_classIL_MNIST_Encoded: EWC | classIL | MNIST_Encoded -> 19.61 ± 0.16 (n=5)\n",
      "Loading sweep: 9vtbd28v\n",
      "ewc_classIL_TinyImageNet: EWC | classIL | TinyImageNet -> 4.64 ± 0.05 (n=5)\n",
      "Loading sweep: 2yy94mkk\n",
      "ewc_classIL_CIFAR5Task: EWC | classIL | CIFAR5Task -> 20.87 ± 1.08 (n=5)\n",
      "Loading sweep: 5li65lf9\n",
      "oewc_taskIL_MNIST: oEWC | taskIL | MNIST -> 99.16 ± 0.39 (n=5)\n",
      "Loading sweep: haxoiwpy\n",
      "oewc_taskIL_MNIST_Encoded: oEWC | taskIL | MNIST_Encoded -> 98.08 ± 1.73 (n=5)\n",
      "Loading sweep: k4ty6g4y\n",
      "oewc_taskIL_CIFAR10: oEWC | taskIL | CIFAR10 -> 95.87 ± 0.68 (n=5)\n",
      "Loading sweep: 1feg452a\n",
      "oewc_taskIL_TinyImageNet: oEWC | taskIL | TinyImageNet -> 35.54 ± 0.29 (n=5)\n",
      "Loading sweep: 96rtma30\n",
      "oewc_classIL_MNIST: oEWC | classIL | MNIST -> 20.20 ± 0.84 (n=5)\n",
      "Loading sweep: 0uwwn7pg\n",
      "oewc_classIL_MNIST_Encoded: oEWC | classIL | MNIST_Encoded -> 19.64 ± 0.05 (n=5)\n",
      "Loading sweep: 5orh3y5r\n",
      "oewc_classIL_CIFAR5Task: oEWC | classIL | CIFAR5Task -> 21.17 ± 2.02 (n=5)\n",
      "Loading sweep: aqqn6apn\n",
      "oewc_classIL_TinyImageNet: oEWC | classIL | TinyImageNet -> 4.75 ± 0.04 (n=5)\n",
      "Loading sweep: lx9cnj04\n",
      "si_taskIL_MNIST: SI | taskIL | MNIST -> 97.52 ± 0.93 (n=5)\n",
      "Loading sweep: t814a2bw\n",
      "si_taskIL_MNIST_Encoded: SI | taskIL | MNIST_Encoded -> 98.50 ± 0.33 (n=5)\n",
      "Loading sweep: w4tyfd49\n",
      "si_taskIL_TinyImageNet: SI | taskIL | TinyImageNet -> 35.52 ± 0.48 (n=5)\n",
      "Loading sweep: yfnqasl9\n",
      "si_taskIL_CIFAR10: SI | taskIL | CIFAR10 -> 92.63 ± 1.27 (n=5)\n",
      "Loading sweep: 0hkmme2n\n",
      "si_classIL_MNIST: SI | classIL | MNIST -> 19.72 ± 0.01 (n=5)\n",
      "Loading sweep: su8091cg\n",
      "si_classIL_MNIST_Encoded: SI | classIL | MNIST_Encoded -> 19.61 ± 0.09 (n=5)\n",
      "Loading sweep: 6upykwgh\n",
      "si_classIL_TinyImageNet: SI | classIL | TinyImageNet -> 4.74 ± 0.04 (n=5)\n",
      "Loading sweep: 8iacqh0q\n",
      "si_classIL_CIFAR5Task: SI | classIL | CIFAR5Task -> 19.81 ± 0.13 (n=5)\n",
      "Loading sweep: sboxxxxq\n",
      "taskIL_MNIST_EFC: EFC | taskIL | MNIST -> 96.80 ± 2.45 (n=5)\n",
      "Loading sweep: v6nmf69q\n",
      "taskIL_EFC_MNIST_Encoded: EFC | taskIL | MNIST_Encoded -> 87.53 ± 4.85 (n=5)\n",
      "Loading sweep: tdjyv5su\n",
      "taskIL_EFC_CIFAR10: EFC | taskIL | CIFAR10 -> 90.61 ± 3.19 (n=5)\n",
      "Loading sweep: i1ejkeu0\n",
      "taskIL_EFC_TinyImageNet: EFC | taskIL | TinyImageNet -> 36.79 ± 0.41 (n=5)\n",
      "Loading sweep: 9nlzz1az\n",
      "ClassIL_EFC_MNIST: EFC | ClassIL | MNIST -> 42.36 ± 2.97 (n=5)\n",
      "Loading sweep: k4by5mb7\n",
      "ClassIL_EFC_MNIST_Encoded: EFC | ClassIL | MNIST_Encoded -> 23.98 ± 5.59 (n=5)\n",
      "Loading sweep: hvjx1v04\n",
      "derpp_taskIL_MNIST: DERPP | taskIL | MNIST -> 98.62 ± 0.17 (n=5)\n",
      "Loading sweep: wix5sln6\n",
      "derpp_taskIL_MNIST_Encoded: DERPP | taskIL | MNIST_Encoded -> 99.21 ± 0.12 (n=5)\n",
      "Loading sweep: 2rpifx70\n",
      "derpp_taskIL_CIFAR10: DERPP | taskIL | CIFAR10 -> 95.37 ± 0.79 (n=5)\n",
      "Loading sweep: 0xu65xiy\n",
      "derpp_taskIL_TinyImageNet: DERPP | taskIL | TinyImageNet -> 19.66 ± 0.44 (n=5)\n",
      "Loading sweep: 4s6brbpu\n",
      "derpp_classIL_MNIST: DERPP | classIL | MNIST -> 71.12 ± 2.16 (n=5)\n",
      "Loading sweep: 27q0bnf9\n",
      "derpp_classIL_MNIST_Encoded: DERPP | classIL | MNIST_Encoded -> 91.40 ± 1.19 (n=5)\n",
      "Loading sweep: t6fxfkja\n",
      "derpp_classIL_TinyImageNet: DERPP | classIL | TinyImageNet -> 4.96 ± 0.26 (n=5)\n",
      "Loading sweep: zxom37nt\n",
      "derpp_classIL_CIFAR5Task: DERPP | classIL | CIFAR5Task -> 62.26 ± 1.16 (n=5)\n",
      "\n",
      "================================================================================\n",
      "RESULTS TABLE\n",
      "================================================================================\n",
      "Setting        taskIL                                                            classIL                                                        ClassIL              \n",
      "Dataset         MNIST MNIST_Encoded       CIFAR10  TinyImageNet CIFAR5Task         MNIST MNIST_Encoded CIFAR10 TinyImageNet    CIFAR5Task         MNIST MNIST_Encoded\n",
      "BP       98.14 ± 0.09  90.73 ± 4.28  95.46 ± 0.51  28.01 ± 0.52        NaN  19.30 ± 0.04  19.25 ± 0.18     NaN  3.79 ± 0.09  19.57 ± 0.07           NaN           NaN\n",
      "EWC      97.21 ± 0.64  96.27 ± 1.63  95.47 ± 1.23  34.20 ± 0.41        NaN           NaN  19.61 ± 0.16     NaN  4.64 ± 0.05  20.87 ± 1.08           NaN           NaN\n",
      "oEWC     99.16 ± 0.39  98.08 ± 1.73  95.87 ± 0.68  35.54 ± 0.29        NaN  20.20 ± 0.84  19.64 ± 0.05     NaN  4.75 ± 0.04  21.17 ± 2.02           NaN           NaN\n",
      "SI       97.52 ± 0.93  98.50 ± 0.33  92.63 ± 1.27  35.52 ± 0.48        NaN  19.72 ± 0.01  19.61 ± 0.09     NaN  4.74 ± 0.04  19.81 ± 0.13           NaN           NaN\n",
      "EFC      96.80 ± 2.45  87.53 ± 4.85  90.61 ± 3.19  36.79 ± 0.41        NaN           NaN           NaN     NaN          NaN           NaN  42.36 ± 2.97  23.98 ± 5.59\n",
      "DERPP    98.62 ± 0.17  99.21 ± 0.12  95.37 ± 0.79  19.66 ± 0.44        NaN  71.12 ± 2.16  91.40 ± 1.19     NaN  4.96 ± 0.26  62.26 ± 1.16           NaN           NaN\n"
     ]
    }
   ],
   "source": [
    "# Build results table from all sweeps\n",
    "# Sweep names follow one of three patterns (setting/method tokens vary in case):\n",
    "#   {method}_{setting}_{dataset}  e.g. \"bp_taskIL_MNIST\"\n",
    "#   {setting}_{dataset}_{method}  e.g. \"taskIL_MNIST_EFC\"\n",
    "#   {setting}_{method}_{dataset}  e.g. \"ClassIL_EFC_MNIST\"\n",
    "\n",
    "# Define the table structure\n",
    "datasets = [\"MNIST\", \"MNIST_Encoded\", \"CIFAR10\", \"TinyImageNet\", \"CIFAR5Task\"]\n",
    "settings = [\"taskIL\", \"classIL\"]\n",
    "methods = [\"BP\", \"EWC\", \"oEWC\", \"SI\", \"EFC\", \"DERPP\"]\n",
    "\n",
    "# Method name mapping (lowercase -> display name)\n",
    "method_map = {\"bp\": \"BP\", \"ewc\": \"EWC\", \"oewc\": \"oEWC\", \"si\": \"SI\", \"efc\": \"EFC\", \"derpp\": \"DERPP\"}\n",
    "\n",
    "# Setting name mapping (lowercase -> canonical column label). Sweep names are not\n",
    "# consistently cased (\"ClassIL_EFC_MNIST\" vs \"derpp_classIL_MNIST\"); without this,\n",
    "# .loc below would silently enlarge the frame with a separate \"ClassIL\" column.\n",
    "setting_map = {s.lower(): s for s in settings}\n",
    "\n",
    "# Create multi-level columns: (setting, dataset)\n",
    "columns = pd.MultiIndex.from_product([settings, datasets], names=[\"Setting\", \"Dataset\"])\n",
    "results_df = pd.DataFrame(index=methods, columns=columns)\n",
    "\n",
    "# Parse sweep names and load results\n",
    "for sweep_name, sweep_id in SWEEP_DICTS.items():\n",
    "    parts = sweep_name.split(\"_\")\n",
    "\n",
    "    if parts[0].lower() in method_map:\n",
    "        # Format: method_setting_dataset\n",
    "        method_key = parts[0].lower()\n",
    "        setting_raw = parts[1]\n",
    "        dataset = \"_\".join(parts[2:])\n",
    "    elif parts[-1].lower() in method_map:\n",
    "        # Format: setting_dataset_method\n",
    "        setting_raw = parts[0]\n",
    "        method_key = parts[-1].lower()\n",
    "        dataset = \"_\".join(parts[1:-1])\n",
    "    else:\n",
    "        # Format: setting_method_dataset\n",
    "        setting_raw = parts[0]\n",
    "        method_key = parts[1].lower()\n",
    "        dataset = \"_\".join(parts[2:])\n",
    "\n",
    "    method_display = method_map.get(method_key, method_key.upper())\n",
    "    setting = setting_map.get(setting_raw.lower())\n",
    "\n",
    "    # Skip sweeps that do not map onto a column of the table\n",
    "    if setting is None:\n",
    "        print(f\"Skipping {sweep_name}: setting '{setting_raw}' not recognized\")\n",
    "        continue\n",
    "    if dataset not in datasets:\n",
    "        print(f\"Skipping {sweep_name}: dataset '{dataset}' not recognized\")\n",
    "        continue\n",
    "\n",
    "    # Load the sweep data\n",
    "    df = load_sweep_results_with_hyperparams([sweep_id], WANDB_DIR)\n",
    "\n",
    "    # Only runs that logged a final accuracy contribute to mean, std and n\n",
    "    acc = pd.Series(dtype=float) if df.empty else df['final_avg_accuracy'].dropna()\n",
    "    if acc.empty:\n",
    "        print(f\"No data for {sweep_name}\")\n",
    "        continue\n",
    "\n",
    "    # Calculate mean ± std (std is NaN when only a single run is available)\n",
    "    mean_acc = acc.mean()\n",
    "    std_acc = acc.std()\n",
    "    n_runs = len(acc)\n",
    "\n",
    "    # Store in results table\n",
    "    results_df.loc[method_display, (setting, dataset)] = f\"{mean_acc:.2f} ± {std_acc:.2f}\"\n",
    "    print(f\"{sweep_name}: {method_display} | {setting} | {dataset} -> {mean_acc:.2f} ± {std_acc:.2f} (n={n_runs})\")\n",
    "\n",
    "print(\"\\n\" + \"=\"*80)\n",
    "print(\"RESULTS TABLE\")\n",
    "print(\"=\"*80)\n",
    "print(results_df.to_string())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "slhpbh8kdah",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved to paper_results.csv\n"
     ]
    }
   ],
   "source": [
    "# Persist the results table so it can be pulled into the paper\n",
    "csv_path = \"paper_results.csv\"\n",
    "results_df.to_csv(csv_path)\n",
    "print(f\"Saved to {csv_path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "efc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
