{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
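        "# Uncomment the two lines below to mount Google Drive when running on Colab;\n",
        "# PROJECT_PATH (defined in the next cell) lives under /content/drive.\n",
        "# from google.colab import drive\n",
        "# drive.mount('/content/drive')"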
      ],
      "metadata": {
        "id": "cBZ4vq9F7C0a"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import Subset, DataLoader\n",
        "import sys\n",
        "import os\n",
        "import pickle\n",
        "\n",
        "# Project root on the mounted Google Drive. Results and checkpoints are written\n",
        "# beneath it; the project-local modules imported later are presumably found here too.\n",
        "PROJECT_PATH = '/content/drive/MyDrive/tfc-sr'\n",
        "\n",
        "# Guarded so that re-running this cell does not pile up duplicate sys.path entries.\n",
        "if PROJECT_PATH not in sys.path:\n",
        "    sys.path.append(PROJECT_PATH)"
      ],
      "metadata": {
        "id": "huAqHGqg-FPf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Shared settings for every experiment in this notebook.\n",
        "CONFIG = {\n",
        "    # Reproducibility and task stream\n",
        "    \"seed\": 42,\n",
        "    \"num_tasks\": 5,\n",
        "    # Model / optimisation\n",
        "    \"epochs_per_task\": 10,\n",
        "    \"batch_size\": 64,\n",
        "    \"lr\": 0.001,\n",
        "    \"num_classes\": 10,\n",
        "    # Output folders, both located under PROJECT_PATH\n",
        "    \"results_path\": os.path.join(PROJECT_PATH, \"results\"),\n",
        "    \"checkpoints_path\": os.path.join(PROJECT_PATH, \"checkpoints\"),\n",
        "    # Regularization strengths for EWC and SI\n",
        "    \"ewc_lambda\": 1.0,\n",
        "    \"si_lambda\": 1.0,\n",
        "}"
      ],
      "metadata": {
        "id": "37XWfgzjYXBH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Make sure both output folders exist before any experiment writes to them.\n",
        "for dir_key in ('results_path', 'checkpoints_path'):\n",
        "    os.makedirs(CONFIG[dir_key], exist_ok=True)"
      ],
      "metadata": {
        "id": "o-QU5MftYlcQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Prefer CUDA when the runtime exposes it; otherwise fall back to the CPU.\n",
        "if torch.cuda.is_available():\n",
        "    device = torch.device(\"cuda\")\n",
        "else:\n",
        "    device = torch.device(\"cpu\")\n",
        "print(f\"Using device: {device}\")"
      ],
      "metadata": {
        "id": "dV1IpHdrYpoP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install -q avalanche-lib\n",
        "\n",
        "from avalanche.benchmarks.classic import SplitMNIST\n",
        "from avalanche.training import EWC, SynapticIntelligence\n",
        "from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics, forgetting_metrics\n",
        "from avalanche.logging import InteractiveLogger, TextLogger\n",
        "from avalanche.training.plugins import EvaluationPlugin\n",
        "\n",
        "\n",
        "# --- 3. UNIFIED AVALANCHE BENCHMARK SETUP ---\n",
        "# This benchmark will be used for all experiments to ensure consistency.\n",
        "# The number of experiences is read from CONFIG so it always matches the\n",
        "# 'num_tasks' value that the training cells print.\n",
        "split_mnist_benchmark = SplitMNIST(n_experiences=CONFIG['num_tasks'], seed=CONFIG['seed'])"
      ],
      "metadata": {
        "id": "e01aPD6r3lrd"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from model import CNN\n",
        "from data_setup import get_split_mnist_dataloaders\n",
        "from utils import set_seed, save_results, plot_results, load_results, evaluate_on_seen_tasks"
      ],
      "metadata": {
        "id": "CzszrKzwGKIU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Seed before the first model is created. set_seed comes from utils; presumably it\n",
        "# seeds every RNG in use - TODO confirm against utils.py.\n",
        "set_seed(CONFIG['seed'])"
      ],
      "metadata": {
        "id": "vYdEBc_kYsWW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: Standard CL (Benchmark) ---\n",
        "# Plain sequential training: one model and one optimizer are carried through\n",
        "# every experience of the stream, with nothing added to counter forgetting.\n",
        "\n",
        "baseline_accuracies = []\n",
        "criterion = nn.CrossEntropyLoss()\n",
        "model = CNN(num_classes=CONFIG['num_classes']).to(device)\n",
        "optimizer = optim.Adam(model.parameters(), lr=CONFIG['lr'])\n",
        "\n",
        "# Walk the experiences of the shared benchmark in stream order\n",
        "for i, experience in enumerate(split_mnist_benchmark.train_stream):\n",
        "    print(f\"\\n--- Training on Task {i+1}/{CONFIG['num_tasks']} ---\")\n",
        "\n",
        "    # Shuffled mini-batches drawn from the current experience only\n",
        "    train_dataset = experience.dataset\n",
        "    train_loader = DataLoader(\n",
        "        train_dataset,\n",
        "        batch_size=CONFIG['batch_size'],\n",
        "        shuffle=True,\n",
        "    )\n",
        "\n",
        "    model.train()\n",
        "    for epoch in range(CONFIG['epochs_per_task']):\n",
        "        for data, targets, task_labels in train_loader:\n",
        "            data = data.to(device)\n",
        "            targets = targets.to(device)\n",
        "\n",
        "            optimizer.zero_grad()\n",
        "            outputs = model(data)\n",
        "            loss = criterion(outputs, targets)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "        # The value printed is the loss of the epoch's final mini-batch\n",
        "        print(f\"Task {i+1}, Epoch {epoch+1}/{CONFIG['epochs_per_task']}, Loss: {loss.item():.4f}\")\n",
        "\n",
        "    # evaluate_on_seen_tasks (from utils) is given the index of the task just trained\n",
        "    accuracy = evaluate_on_seen_tasks(model, split_mnist_benchmark, i, device, CONFIG['batch_size'])\n",
        "    baseline_accuracies.append(accuracy)\n",
        "    print(f\"----- Accuracy after Task {i+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "# Persist the weights reached after the last task\n",
        "final_model_path = os.path.join(CONFIG['checkpoints_path'], 'baseline_final_model.pth')\n",
        "torch.save(model.state_dict(), final_model_path)\n",
        "print(f\"\\nFinal baseline model saved to {final_model_path}\")\n",
        "\n",
        "# Persist the accuracy list; later cells reload it from this file\n",
        "baseline_results_path = os.path.join(CONFIG['results_path'], 'baseline_accuracies.pkl')\n",
        "save_results(baseline_accuracies, baseline_results_path)\n",
        "\n",
        "# Accuracy curve of the baseline on its own\n",
        "results_to_plot = {'Baseline': baseline_accuracies}\n",
        "plot_results(results_to_plot, title=\"Baseline Performance on Split MNIST\")"
      ],
      "metadata": {
        "id": "8KjyRRV-Y5PX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from utils import ReservoirReplayBuffer\n",
        "\n",
        "# Replay settings used by the rehearsal-based experiments below.\n",
        "CONFIG.update({\n",
        "    'buffer_capacity': 200,  # total samples to store across all tasks\n",
        "    'replay_batch_size': CONFIG['batch_size'] // 2,  # half of the main batch_size\n",
        "})"
      ],
      "metadata": {
        "id": "k6nNLUCprmO0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: Standard ER ---\n",
        "# Experience Replay: every optimisation step trains on a mix of up to\n",
        "# replay_batch_size samples from the current task plus replay_batch_size\n",
        "# samples drawn from the replay buffer.\n",
        "set_seed(CONFIG['seed'])\n",
        "\n",
        "model_er = CNN(num_classes=CONFIG['num_classes']).to(device)\n",
        "optimizer_er = optim.Adam(model_er.parameters(), lr=CONFIG['lr'])\n",
        "criterion_er = nn.CrossEntropyLoss()\n",
        "\n",
        "replay_buffer_er = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])\n",
        "\n",
        "er_accuracies = []\n",
        "\n",
        "# Main continual learning loop\n",
        "for task_id, experience in enumerate(split_mnist_benchmark.train_stream):\n",
        "    print(f\"\\n--- Training on Task {task_id+1}/{len(split_mnist_benchmark.train_stream)} ---\")\n",
        "\n",
        "    # --- Step A: Populate the replay buffer with some examples from the new task ---\n",
        "    # We do this before training on the task itself\n",
        "    print(f\"Populating replay buffer from Task {task_id+1}...\")\n",
        "\n",
        "    for data_point, target, _ in experience.dataset:\n",
        "        replay_buffer_er.add(data_point, target)\n",
        "    print(f\"Replay buffer size: {len(replay_buffer_er)}\")\n",
        "\n",
        "    # Build the loader for THIS experience. Without it the loop below silently\n",
        "    # reuses whatever `train_loader` an earlier cell left in the kernel, so every\n",
        "    # ER task would be trained on another experiment's (stale) data.\n",
        "    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)\n",
        "\n",
        "    # --- Step B: Training loop with mixed batches ---\n",
        "    model_er.train()\n",
        "    for epoch in range(CONFIG['epochs_per_task']):\n",
        "        for new_data, new_targets, _ in train_loader:\n",
        "            # Only proceed if we have something in the buffer to replay\n",
        "            if len(replay_buffer_er) > CONFIG['replay_batch_size']:\n",
        "                # 1. Sample a batch from the replay buffer\n",
        "                old_data, old_targets = replay_buffer_er.sample(CONFIG['replay_batch_size'])\n",
        "\n",
        "                # 2. Create the mixed batch\n",
        "                # Trim the new-data batch to the replay batch size so the mix is 50/50\n",
        "                # (the final, shorter batch of an epoch may contribute fewer new samples)\n",
        "                new_data = new_data[:CONFIG['replay_batch_size']]\n",
        "                new_targets = new_targets[:CONFIG['replay_batch_size']]\n",
        "\n",
        "                combined_data = torch.cat((new_data, old_data), dim=0).to(device)\n",
        "                combined_targets = torch.cat((new_targets, old_targets), dim=0).to(device)\n",
        "\n",
        "                # 3. Standard training step on the mixed batch\n",
        "                optimizer_er.zero_grad()\n",
        "                outputs = model_er(combined_data)\n",
        "                loss = criterion_er(outputs, combined_targets)\n",
        "                loss.backward()\n",
        "                optimizer_er.step()\n",
        "\n",
        "        print(f\"Task {task_id+1}, Epoch {epoch+1}, Last batch loss: {loss.item():.4f}\")\n",
        "\n",
        "    # --- Step C: Evaluation loop ---\n",
        "    accuracy = evaluate_on_seen_tasks(model_er, split_mnist_benchmark, task_id, device, CONFIG['batch_size'])\n",
        "    er_accuracies.append(accuracy)\n",
        "    print(f\"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "# Save checkpoint\n",
        "er_model_path = os.path.join(CONFIG['checkpoints_path'], 'er_final_model.pth')\n",
        "torch.save(model_er.state_dict(), er_model_path)\n",
        "print(f\"\\nFinal ER model saved to {er_model_path}\")\n",
        "\n",
        "# Save results\n",
        "er_results_path = os.path.join(CONFIG['results_path'], 'er_accuracies.pkl')\n",
        "save_results(er_accuracies, er_results_path)\n",
        "\n",
        "# Plot comparison (baseline accuracies are reloaded from the file the baseline cell saved)\n",
        "baseline_accuracies = load_results(os.path.join(CONFIG['results_path'], 'baseline_accuracies.pkl'))\n",
        "results_to_plot = {\n",
        "    'Baseline': baseline_accuracies,\n",
        "    'Standard ER': er_accuracies\n",
        "}\n",
        "plot_results(results_to_plot, title=\"Standard ER vs. Baseline on Split MNIST\")"
      ],
      "metadata": {
        "id": "KLgoSjIXaOqO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: Task-Focused Consolidation with Spaced Repetition (TFC-SR) ---\n",
        "# Mixed-batch replay training plus a per-task schedule of \"memory checks\": when\n",
        "# the model scores at least the mastery threshold on the replay buffer, the gap\n",
        "# until the next check grows; otherwise the check is repeated on the next epoch.\n",
        "set_seed(CONFIG['seed'])\n",
        "\n",
        "from utils import create_buffer_validation_set, evaluate_replay_buffer\n",
        "\n",
        "# --- 1. CONFIGURATION ---\n",
        "# New parameters for TFC-SR\n",
        "CONFIG['mastery_threshold'] = 95.0 # Accuracy threshold in %\n",
        "CONFIG['initial_replay_gap'] = 1   # Start checking after epoch 1\n",
        "CONFIG['replay_gap_multiplier'] = 1.5 # How much to increase the gap\n",
        "\n",
        "# --- 2. RUN TFC-SR EXPERIMENT ---\n",
        "print(\"\\n===== Starting TFC-SR Experiment =====\")\n",
        "\n",
        "model_tfc = CNN(num_classes=CONFIG['num_classes']).to(device)\n",
        "optimizer_tfc = optim.Adam(model_tfc.parameters(), lr=CONFIG['lr'])\n",
        "criterion_tfc = nn.CrossEntropyLoss()\n",
        "\n",
        "replay_buffer_tfc = ReservoirReplayBuffer(capacity=CONFIG['buffer_capacity'])\n",
        "tfc_accuracies = []\n",
        "\n",
        "# Replayed samples per step; read from CONFIG (batch_size // 2) so this\n",
        "# experiment uses the same setting as the Standard ER cell.\n",
        "replay_batch_size = CONFIG['replay_batch_size']\n",
        "\n",
        "# --- Main Continual Learning Loop ---\n",
        "for task_id, experience in enumerate(split_mnist_benchmark.train_stream):\n",
        "    print(f\"\\n--- Training on Task {task_id+1}/{len(split_mnist_benchmark.train_stream)} ---\")\n",
        "\n",
        "    # Populate replay buffer\n",
        "    for data_point, target, _ in experience.dataset:\n",
        "        replay_buffer_tfc.add(data_point, target)\n",
        "    print(f\"Replay buffer size: {len(replay_buffer_tfc)}\")\n",
        "\n",
        "    # Initialize the replay schedule for this new task.\n",
        "    # replay_timer holds the 1-based epoch number at which the next check runs.\n",
        "    current_replay_gap = float(CONFIG['initial_replay_gap'])\n",
        "    replay_timer = int(current_replay_gap)\n",
        "\n",
        "    train_loader = DataLoader(experience.dataset, batch_size=CONFIG['batch_size'], shuffle=True)\n",
        "\n",
        "    # --- Training loop for the current task ---\n",
        "    model_tfc.train()\n",
        "    for epoch in range(CONFIG['epochs_per_task']):\n",
        "        # --- Mixed-batch training ---\n",
        "        for new_data, new_targets, _ in train_loader:\n",
        "            if len(replay_buffer_tfc) > replay_batch_size:\n",
        "                old_data, old_targets = replay_buffer_tfc.sample(replay_batch_size)\n",
        "                new_data = new_data[:replay_batch_size]\n",
        "                new_targets = new_targets[:replay_batch_size]\n",
        "                combined_data = torch.cat((new_data, old_data), dim=0).to(device)\n",
        "                combined_targets = torch.cat((new_targets, old_targets), dim=0).to(device)\n",
        "\n",
        "                optimizer_tfc.zero_grad()\n",
        "                outputs = model_tfc(combined_data)\n",
        "                loss = criterion_tfc(outputs, combined_targets)\n",
        "                loss.backward()\n",
        "                optimizer_tfc.step()\n",
        "\n",
        "        # No newline yet: the memory-check branch below finishes this output line\n",
        "        print(f\"Task {task_id+1}, Epoch {epoch+1}, Loss: {loss.item():.4f}\", end=\"\")\n",
        "\n",
        "        # --- Adaptive Replay Scheduling Logic ---\n",
        "        if (epoch + 1) == replay_timer and len(replay_buffer_tfc) > 1:\n",
        "            print(\" <-- Memory Check!\", end=\"\")\n",
        "            model_tfc.eval()\n",
        "\n",
        "            replay_perf = evaluate_replay_buffer(model_tfc, replay_buffer_tfc, device)\n",
        "\n",
        "            print(f\" | Replay Perf: {replay_perf:.2f}%\", end=\"\")\n",
        "\n",
        "            if replay_perf >= CONFIG['mastery_threshold']:\n",
        "                # Mastered: widen the gap and push the next check further out\n",
        "                current_replay_gap *= CONFIG['replay_gap_multiplier']\n",
        "                replay_timer += round(current_replay_gap)\n",
        "                print(f\" | Mastery OK. Next check @ epoch {replay_timer}.\")\n",
        "            else:\n",
        "                # Not mastered: check again on the very next epoch. replay_timer\n",
        "                # already holds that epoch after the increment, so report it as-is.\n",
        "                replay_timer += 1\n",
        "                print(f\" | Mastery FAIL. Next check @ epoch {replay_timer}.\")\n",
        "\n",
        "            model_tfc.train()\n",
        "        else:\n",
        "            print()\n",
        "\n",
        "    # --- Final Evaluation for this task (same as before) ---\n",
        "    accuracy = evaluate_on_seen_tasks(model_tfc, split_mnist_benchmark, task_id, device, CONFIG['batch_size'])\n",
        "    tfc_accuracies.append(accuracy)\n",
        "    print(f\"----- Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "# --- 3. SAVE AND PLOT ---\n",
        "# Save checkpoint and results\n",
        "tfc_model_path = os.path.join(CONFIG['checkpoints_path'], 'tfc_sr_final_model.pth')\n",
        "torch.save(model_tfc.state_dict(), tfc_model_path)\n",
        "tfc_results_path = os.path.join(CONFIG['results_path'], 'tfc_sr_accuracies.pkl')\n",
        "save_results(tfc_accuracies, tfc_results_path)\n",
        "\n",
        "# Plot comparison with previous results\n",
        "baseline_accuracies = load_results(os.path.join(CONFIG['results_path'], 'baseline_accuracies.pkl'))\n",
        "er_accuracies = load_results(os.path.join(CONFIG['results_path'], 'er_accuracies.pkl'))\n",
        "results_to_plot = {\n",
        "    'Baseline': baseline_accuracies,\n",
        "    'Standard ER': er_accuracies,\n",
        "    'TFC-SR': tfc_accuracies\n",
        "}\n",
        "plot_results(results_to_plot, title=\"TFC-SR vs. Baselines on Split MNIST\")"
      ],
      "metadata": {
        "id": "9iiujqvomf66"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: ELASTIC WEIGHT CONSOLIDATION (EWC) ---\n",
        "# Runs Avalanche's EWC strategy once per candidate lambda and keeps the run\n",
        "# whose accuracy after the last task is highest.\n",
        "set_seed(CONFIG['seed'])\n",
        "\n",
        "print(\"\\n\" + \"=\"*20 + \" Starting EWC Experiment \" + \"=\"*20)\n",
        "\n",
        "# --- Hyperparameter Search for EWC ---\n",
        "ewc_lambdas_to_try = [1.0, 100.0, 1000.0, 10000.0]\n",
        "all_ewc_results = {}\n",
        "\n",
        "for lmbda in ewc_lambdas_to_try:\n",
        "    print(f\"\\n--- Running EWC with lambda = {lmbda} ---\")\n",
        "\n",
        "    # Re-seed so every lambda starts from the same seeded state\n",
        "    set_seed(CONFIG['seed'])\n",
        "\n",
        "    # --- Setup strategy for this trial ---\n",
        "    model_ewc = CNN(num_classes=CONFIG['num_classes']).to(device)\n",
        "    optimizer_ewc = optim.Adam(model_ewc.parameters(), lr=CONFIG['lr'])\n",
        "\n",
        "    # Arguments are passed by keyword: this works on every Avalanche release,\n",
        "    # including the ones that deprecate positional strategy arguments.\n",
        "    ewc_strategy = EWC(\n",
        "        model=model_ewc,\n",
        "        optimizer=optimizer_ewc,\n",
        "        criterion=nn.CrossEntropyLoss(),\n",
        "        ewc_lambda=lmbda,\n",
        "        train_mb_size=CONFIG['batch_size'],\n",
        "        train_epochs=CONFIG['epochs_per_task'],\n",
        "        device=device\n",
        "    )\n",
        "\n",
        "    # List to store accuracies for this specific lambda run\n",
        "    current_lambda_accuracies = []\n",
        "\n",
        "    # --- Training and Evaluation Loop ---\n",
        "    for task_id, experience in enumerate(split_mnist_benchmark.train_stream):\n",
        "        print(f\"--> Training on experience {task_id+1}\")\n",
        "\n",
        "        ewc_strategy.train(experience)\n",
        "\n",
        "        accuracy = evaluate_on_seen_tasks(\n",
        "            ewc_strategy.model,\n",
        "            split_mnist_benchmark,\n",
        "            task_id,\n",
        "            device,\n",
        "            CONFIG['batch_size']\n",
        "        )\n",
        "        current_lambda_accuracies.append(accuracy)\n",
        "        print(f\"----- Avg Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "    # Store the results for this lambda\n",
        "    all_ewc_results[lmbda] = current_lambda_accuracies\n",
        "\n",
        "# --- Find the best EWC result and save it ---\n",
        "# \"Best\" means the highest accuracy recorded after the final task.\n",
        "best_lambda_ewc = max(all_ewc_results, key=lambda k: all_ewc_results[k][-1])\n",
        "CONFIG['best_ewc_lambda'] = best_lambda_ewc\n",
        "best_ewc_accuracies = all_ewc_results[best_lambda_ewc]\n",
        "\n",
        "print(f\"\\nBest EWC lambda was {best_lambda_ewc} with final accuracy: {best_ewc_accuracies[-1]:.2f}%\")\n",
        "\n",
        "# --- SAVE AND PLOT ---\n",
        "ewc_results_path = os.path.join(CONFIG['results_path'], 'ewc_accuracies.pkl')\n",
        "save_results(best_ewc_accuracies, ewc_results_path)\n",
        "print(f\"\\nEWC results saved to {ewc_results_path}\")\n",
        "\n",
        "# Reload the curves saved by the earlier experiments for a joint plot\n",
        "baseline_accuracies = load_results(os.path.join(CONFIG['results_path'], 'baseline_accuracies.pkl'))\n",
        "er_accuracies = load_results(os.path.join(CONFIG['results_path'], 'er_accuracies.pkl'))\n",
        "tfc_accuracies = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_accuracies.pkl'))\n",
        "\n",
        "results_to_plot = {\n",
        "    'Baseline': baseline_accuracies,\n",
        "    'Standard ER': er_accuracies,\n",
        "    f'EWC (λ={best_lambda_ewc})': best_ewc_accuracies,\n",
        "    'TFC-SR': tfc_accuracies\n",
        "}\n",
        "plot_results(results_to_plot, title=\"All Methods vs. Baselines on Split MNIST\")"
      ],
      "metadata": {
        "id": "rxuXPwK8Faud"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- EXPERIMENT: SYNAPTIC INTELLIGENCE (SI) ---\n",
        "# Runs Avalanche's SynapticIntelligence strategy once per candidate lambda and\n",
        "# keeps the run whose accuracy after the last task is highest.\n",
        "\n",
        "print(\"\\n\" + \"=\"*20 + \" Starting SI Experiment \" + \"=\"*20)\n",
        "\n",
        "# --- Hyperparameter Search for SI ---\n",
        "si_lambdas_to_try = [0.1, 1.0, 10.0, 100.0]\n",
        "all_si_results = {}\n",
        "\n",
        "for lmbda in si_lambdas_to_try:\n",
        "    print(f\"\\n--- Running SI with lambda = {lmbda} ---\")\n",
        "\n",
        "    # Re-seed so every lambda starts from the same seeded state\n",
        "    set_seed(CONFIG['seed'])\n",
        "\n",
        "    # --- Setup strategy for this trial ---\n",
        "    model_si = CNN(num_classes=CONFIG['num_classes']).to(device)\n",
        "    optimizer_si = optim.Adam(model_si.parameters(), lr=CONFIG['lr'])\n",
        "\n",
        "    # Instantiate the SynapticIntelligence strategy. Arguments are passed by\n",
        "    # keyword: this works on every Avalanche release, including the ones that\n",
        "    # deprecate positional strategy arguments.\n",
        "    si_strategy = SynapticIntelligence(\n",
        "        model=model_si,\n",
        "        optimizer=optimizer_si,\n",
        "        criterion=nn.CrossEntropyLoss(),\n",
        "        si_lambda=lmbda,\n",
        "        train_mb_size=CONFIG['batch_size'],\n",
        "        train_epochs=CONFIG['epochs_per_task'],\n",
        "        device=device\n",
        "    )\n",
        "\n",
        "    current_lambda_accuracies = []\n",
        "\n",
        "    # --- Training and Evaluation Loop ---\n",
        "    for task_id, experience in enumerate(split_mnist_benchmark.train_stream):\n",
        "        print(f\"--> Training on experience {task_id+1}\")\n",
        "\n",
        "        si_strategy.train(experience)\n",
        "\n",
        "        # --- EVALUATION STEP ---\n",
        "        accuracy = evaluate_on_seen_tasks(\n",
        "            si_strategy.model,\n",
        "            split_mnist_benchmark,\n",
        "            task_id,\n",
        "            device,\n",
        "            CONFIG['batch_size']\n",
        "        )\n",
        "        current_lambda_accuracies.append(accuracy)\n",
        "        print(f\"----- Avg Accuracy after Task {task_id+1}: {accuracy:.2f}% -----\")\n",
        "\n",
        "    # Store the results for this lambda\n",
        "    all_si_results[lmbda] = current_lambda_accuracies\n",
        "\n",
        "# --- Find the best SI result and save it ---\n",
        "# \"Best\" means the highest accuracy recorded after the final task.\n",
        "best_lambda_si = max(all_si_results, key=lambda k: all_si_results[k][-1])\n",
        "# Recorded in CONFIG as well, mirroring what the EWC cell does for its lambda\n",
        "CONFIG['best_si_lambda'] = best_lambda_si\n",
        "best_si_accuracies = all_si_results[best_lambda_si]\n",
        "\n",
        "print(f\"\\nBest SI lambda was {best_lambda_si} with final accuracy: {best_si_accuracies[-1]:.2f}%\")\n",
        "\n",
        "si_results_path = os.path.join(CONFIG['results_path'], 'si_accuracies.pkl')\n",
        "save_results(best_si_accuracies, si_results_path)\n",
        "\n",
        "\n",
        "# --- PLOT ALL RESULTS TOGETHER ---\n",
        "# (Load all previous results and plot everything)\n",
        "baseline_accuracies = load_results(os.path.join(CONFIG['results_path'], 'baseline_accuracies.pkl'))\n",
        "er_accuracies = load_results(os.path.join(CONFIG['results_path'], 'er_accuracies.pkl'))\n",
        "ewc_accuracies = load_results(os.path.join(CONFIG['results_path'], 'ewc_accuracies.pkl'))\n",
        "tfc_accuracies = load_results(os.path.join(CONFIG['results_path'], 'tfc_sr_accuracies.pkl'))\n",
        "# The EWC curve comes from disk, but its lambda only exists in CONFIG when the EWC\n",
        "# cell ran in this kernel; fall back to a placeholder rather than raising KeyError\n",
        "# after the whole SI search has finished.\n",
        "best_ewc_lambda = CONFIG.get('best_ewc_lambda', 'unknown')\n",
        "\n",
        "results_to_plot = {\n",
        "    'Baseline': baseline_accuracies,\n",
        "    'Standard ER': er_accuracies,\n",
        "    f'EWC (Best λ={best_ewc_lambda})': ewc_accuracies,\n",
        "    f'SI (Best λ={best_lambda_si})': best_si_accuracies,\n",
        "    'TFC-SR': tfc_accuracies\n",
        "}\n",
        "plot_results(results_to_plot, title=\"All Methods vs. Baselines on Split MNIST\")"
      ],
      "metadata": {
        "id": "vYrs7JlXRkfG"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}